diff --git a/README.md b/README.md index 0496ac8..bdb8c73 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ where `NAME` need to be replaced by * `rc` for the RC model (Ex 3). * `rcr` for the RCR model (Ex 4). * `adaann` for the Friedman model example (Ex 5). -* `rcr_nofas_adaann_example` for the RCR model, combining NoFAS with adaptive annealing (AdaAnn) +* `rcr_nofas_adaann` for the RCR model, combining NoFAS with adaptive annealing (AdaAnn) At regular intervals set by the parameter `experiment.save_interval` LINFA writes a few results files. The sub-string `NAME` refers to the experiment name specified in the `experiment.name` variable, and `IT` indicates the iteration at which the file is written. The results files are @@ -73,7 +73,7 @@ At regular intervals set by the parameter `experiment.save_interval` LINFA write A post processing script is also available to plot all results. To run it type ```sh -python linfa.plot_res -n NAME -i IT -f FOLDER +python -m linfa.plot_res -n NAME -i IT -f FOLDER ``` where `NAME` and `IT` are again the experiment name and iteration number corresponding to the result file of interest, while `FOLDER` is the name of the folder with the results of the inference task are kept. diff --git a/docs/content/imgs/rcr/data_plot_rcr_25000_0_1-1.png b/docs/content/imgs/rcr/data_plot_rcr_25000_0_1-1.png deleted file mode 100644 index 0f45b5c..0000000 Binary files a/docs/content/imgs/rcr/data_plot_rcr_25000_0_1-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/data_plot_rcr_25000_0_2-1.png b/docs/content/imgs/rcr/data_plot_rcr_25000_0_2-1.png deleted file mode 100644 index b7e7733..0000000 Binary files a/docs/content/imgs/rcr/data_plot_rcr_25000_0_2-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/data_plot_rcr_25000_1_2-1.png b/docs/content/imgs/rcr/data_plot_rcr_25000_1_2-1.png deleted file mode 100644 index 52aeea9..0000000 Binary files a/docs/content/imgs/rcr/data_plot_rcr_25000_1_2-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_1-1.png b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_1-1.png new file mode 100644 index 0000000..18b0d18 Binary files /dev/null and b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_1-1.png differ diff --git a/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_2-1.png b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_2-1.png new file mode 100644 index 0000000..327a31e Binary files /dev/null and b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_0_2-1.png differ diff --git a/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_1_2-1.png b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_1_2-1.png new file mode 100644 index 0000000..a6e50c3 Binary files /dev/null and b/docs/content/imgs/rcr/data_plot_rcr_nofas_adaann_8400_1_2-1.png differ diff --git a/docs/content/imgs/rcr/log_plot-1.png b/docs/content/imgs/rcr/log_plot-1.png new file mode 100644 index 0000000..7886549 Binary files /dev/null and b/docs/content/imgs/rcr/log_plot-1.png differ diff --git a/docs/content/imgs/rcr/log_plot_rcr-1.png b/docs/content/imgs/rcr/log_plot_rcr-1.png deleted file mode 100644 index cdf3dc2..0000000 Binary files a/docs/content/imgs/rcr/log_plot_rcr-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_25000_0_1-1.png b/docs/content/imgs/rcr/params_plot_rcr_25000_0_1-1.png deleted file mode 100644 index 4341b5e..0000000 Binary files a/docs/content/imgs/rcr/params_plot_rcr_25000_0_1-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_25000_0_2-1.png b/docs/content/imgs/rcr/params_plot_rcr_25000_0_2-1.png deleted file mode 100644 index 7d466f4..0000000 Binary files a/docs/content/imgs/rcr/params_plot_rcr_25000_0_2-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_25000_1_2-1.png b/docs/content/imgs/rcr/params_plot_rcr_25000_1_2-1.png deleted file mode 100644 index b65ebef..0000000 Binary files a/docs/content/imgs/rcr/params_plot_rcr_25000_1_2-1.png and /dev/null differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_1-1.png b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_1-1.png new file mode 100644 index 0000000..87022ad Binary files /dev/null and b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_1-1.png differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_2-1.png b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_2-1.png new file mode 100644 index 0000000..266cda5 Binary files /dev/null and b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_0_2-1.png differ diff --git a/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_1_2-1.png b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_1_2-1.png new file mode 100644 index 0000000..4fe4df0 Binary files /dev/null and b/docs/content/imgs/rcr/params_plot_rcr_nofas_adaann_8400_1_2-1.png differ diff --git a/docs/content/rcr.rst b/docs/content/rcr.rst index f1d51bb..d4f3214 100644 --- a/docs/content/rcr.rst +++ b/docs/content/rcr.rst @@ -1,7 +1,7 @@ Three-element Wndkessel Model ============================= -The three-parameter Windkessel or **RCR** model is characterized by proximal and distal resistance parameters :math:`R_{p}, R_{d} \in [100, 1500]`$ Barye :math:`\cdot` s/ml and one capacitance parameter :math:`C \in [1\times 10^{-5}, 1\times 10^{-2}]` ml/Barye. +The three-parameter Windkessel or **RCR** model is characterized by proximal and distal resistance parameters :math:`R_{p}, R_{d} \in [100, 1500]` Barye :math:`\cdot` s/ml and one capacitance parameter :math:`C \in [1\times 10^{-5}, 1\times 10^{-2}]` ml/Barye. This model is not identifiable. The average distal pressure is only affected by the total system resistance, i.e. the sum :math:`R_{p}+R_{d}`, leading to a negative correlation between these two parameters. Thus, an increment in the proximal resistance is compensated by a reduction in the distal resistance (so the average distal pressure remains the same) which, in turn, reduces the friction encountered by the flow exiting the capacitor. An increase in the value of :math:`C` is finally needed to restore the average, minimum and maximum pressure. This leads to a positive correlation between :math:`C` and :math:`R_{d}`. @@ -14,14 +14,14 @@ where the distal pressure is set to :math:`P_{d}=55` mmHg. Synthetic observations are generated from :math:`N(\boldsymbol\mu, \boldsymbol\Sigma)`, where :math:`\mu=(f_{1}(\boldsymbol{z}^{*}),f_{2}(\boldsymbol{z}^{*}),f_{3}(\boldsymbol{z}^{*}))^T = (P_{p,\text{min}}, P_{p,\text{max}}, P_{p,\text{ave}})^T = (100.96, 148.02,116.50)^T` and :math:`\boldsymbol\Sigma`` is a diagonal matrix with entries :math:`(5.05, 7.40, 5.83)^T`. The budgeted number of true model solutions is 216; the fixed surrogate model is evaluated on a :math:`6\times 6\times 6 = 216` pre-grid while the adaptive surrogate is evaluated with a pre-grid of size :math:`4\times 4\times 4 = 64` and the other 152 evaluations are adaptively selected. -The results are presented in :numref:`fig_rcr_res`. The posterior samples obtained through NoFAS capture well the non-linear correlation among the parameters and generate a fairly accurate posterior predictive distribution that overlaps with the observations. Additional details can be found in :cite:p:`wang2022variational`. +This example also demonstrates how NoFAS can be combined with annealing for improved convergence. The results in :numref:`fig_rcr_res` are generated using the AdaAnn adaptive annealing scheduler with intial inverse temperature :math:`t_{0}=0.05`, KL tolerance :math:`\tau=0.01` and a batch size of 100 samples. The number of parameter updates is set to 500, 5000 and 5 for :math:`t_{0}`, :math:`t_{1}` and :math:`t_{0}\n", - "\n", - "* Under our model **Phys** class, we define three functions: \n", - " * `__init__`: The constructor function. Class variables are defined.\n", - " * `genDataFile`: The data file generator function. \n", - " * `solve_t`: The solver function.\n", - " * *check comments for detailed documentation.* " - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "#### Implementation of the traditional trajectory motion physics problem ####\n", - "class Phys:\n", - " \n", - " ### Define constructor function for Phys class ###\n", - " def __init__(self):\n", - " ## Define input parameters (True value) \n", - " # input[] = [starting_position, initial_velocity, angle] = [1(m), 5(m/s), 60(degs)]\n", - " self.defParam = torch.Tensor([[1.0, 5.0, 60.0]])\n", - "\n", - " self.gConst = 9.81 # gravitational constant\n", - " self.stdRatio = 0.05 # standard deviation ratio\n", - " self.data = None # data set of model sample\n", - "\n", - " ### Define data file generator function###\n", - " # dataSize (int): size of sample (data)\n", - " # dataFileName (String): name of the sample data file\n", - " # store (Boolean): True if user wish to store the generated data file; False otherwise.\n", - " def genDataFile(self, dataSize = 50, dataFileName=\"data_phys.txt\", store=True):\n", - " def_out = self.solve_t(self.defParam)[0]\n", - " print(def_out)\n", - " self.data = def_out + self.stdRatio * torch.abs(def_out) * torch.normal(0, 1, size=(dataSize, 3))\n", - " self.data = self.data.t().detach().numpy()\n", - " if store: np.savetxt(dataFileName, self.data)\n", - " return self.data\n", - "\n", - " ### Define data file generator function###\n", - " # params (Tensor): input parameters storing starting position, initial velocity, and angle in corresponding order.\n", - " def solve_t(self, params):\n", - " z1, z2, z3 = torch.chunk(params, chunks=3, dim=1) # input parameters\n", - " z3 = z3 * (np.pi / 180) # convert unit from degree to radians\n", - " \n", - " ## Output value calculation\n", - " # ouput[] = [maximum_height, final_location, total_time]\n", - " x = torch.cat(( (z2 * z2 * torch.sin(z3) * torch.sin(z3)) / (2.0 * self.gConst), # x1: maxHeight\n", - " z1 + ((z2 * z2 * torch.sin(2.0 * z3)) / self.gConst), # x2: finalLocation \n", - " (2.0 * z2 * torch.sin(z3)) / self.gConst), 1) \n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0.9557, 3.2070, 0.8828])\n" - ] - } - ], - "source": [ - "## Generate phys sample file ##\n", - "\n", - "# Define model\n", - "model = Phys()\n", - "\n", - "# Generate Data\n", - "physData = model.genDataFile()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2. Check for Gradient Calculation\n", - "* Prior to applying NOFAS to our Phys model, we need to check if the gradient calculated by our model matches the gradient calculated by PyTorch to ensure it's functionality without including the surrogate. \n", - "* Specifically, when surrogate is not enabled, gradient calculation is completed straight through the model so we want to ensure that the model is capable to calculate the correct gradients before applying NOFAS.\n", - "* In this step, we compute each gradient using (1) Pytorch and the (2) Phys model and compare the values to ensure model functionality.\n", - "* We proceed with the following order:\n", - " 1) Gradient Calculation with PyTorch\n", - " 2) Gradient Calculation by Phys Model itself\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " ##### (1) Gradient calculation with PyTorch ver 1" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "#### Implementation of gradient calculation using PyTorch - version 1 #### \n", - "class PytorchGrad1: \n", - " ### Define constructor function for PytorchGrad1 class ###\n", - " def __init__(self):\n", - " # Define gravitational constant\n", - " self.gConst = 9.81\n", - " \n", - " # Define each input manually so that it enables gradient calculation\n", - " self.z1 = torch.tensor(1.0, requires_grad = True)\n", - " self.z2 = torch.tensor(5.0, requires_grad = True)\n", - " self.z3 = torch.tensor(60.0 * np.pi / 180, requires_grad = True)\n", - " \n", - " # Define each output manually reflecting the inputs above\n", - " self.x1 = (self.z2 ** 2) * (torch.sin(self.z3) ** 2) / (2.0 * self.gConst) \n", - " self.x2 = self.z1 + ((self.z2 ** 2) * torch.sin(2.0 * self.z3)) / self.gConst \n", - " self.x3 = (2.0 * self.z2 * torch.sin(self.z3)) / self.gConst\n", - "\n", - " ### Compute gradients using backward function ###\n", - " def back_x1(self): \n", - " self.x1.backward()\n", - " dz1 = self.z1.grad\n", - " dz2 = self.z2.grad\n", - " dz3 = self.z3.grad\n", - " return [dz1, dz2, dz3]\n", - " \n", - " def back_x2(self): \n", - " self.x2.backward()\n", - " dz1 = self.z1.grad\n", - " dz2 = self.z2.grad\n", - " dz3 = self.z3.grad\n", - " return [dz1, dz2, dz3]\n", - " \n", - " def back_x3(self): \n", - " self.x3.backward()\n", - " dz1 = self.z1.grad\n", - " dz2 = self.z2.grad\n", - " dz3 = self.z3.grad\n", - " return [dz1, dz2, dz3]\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dz1dz2dz3
dx1Nonetensor(0.3823)tensor(1.1035)
dx2tensor(1.)tensor(0.8828)tensor(-2.5484)
dx3Nonetensor(0.1766)tensor(0.5097)
\n", - "
" - ], - "text/plain": [ - " dz1 dz2 dz3\n", - "dx1 None tensor(0.3823) tensor(1.1035)\n", - "dx2 tensor(1.) tensor(0.8828) tensor(-2.5484)\n", - "dx3 None tensor(0.1766) tensor(0.5097)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "## Gradient calculation with PyTorch ver 1\n", - "\n", - "# List to store dx/dz values\n", - "dx_dz_pytorch1 = [] \n", - "\n", - "# Set pytorchGrad1 object\n", - "pyGrad1 = PytorchGrad1()\n", - "# Calculate gradient and add to list\n", - "dx_dz_pytorch1.append(pyGrad1.back_x1())\n", - "\n", - "# Reset pytorchGrad1 object\n", - "pyGrad1 = PytorchGrad1()\n", - "# Calculate gradient and add to list\n", - "dx_dz_pytorch1.append(pyGrad1.back_x2())\n", - "\n", - "# Reset pytorchGrad1 object\n", - "pyGrad1 = PytorchGrad1()\n", - "# Calculate gradient and add to list\n", - "dx_dz_pytorch1.append(pyGrad1.back_x3())\n", - "\n", - "# print(dx_dz_pytorch1[1]) # check if values match as expected\n", - "\n", - "# Convert to pandas-DataFrame for readability\n", - "jacob_mat_1 = pd.DataFrame(dx_dz_pytorch1, columns=['dz1', 'dz2', 'dz3'])\n", - "jacob_mat_1.index = ['dx1', 'dx2', 'dx3']\n", - "jacob_mat_1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " ##### (1) Gradient calculation with PyTorch ver 2" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "#### Implementation of gradient calculation using PyTorch - version 2 #### \n", - "class PytorchGrad2: \n", - " ### Define constructor function for PytorchGrad2 class ###\n", - " def __init__(self, model, transform):\n", - " # Define input parameters and enable gradient calculation\n", - " self.z = torch.Tensor([[1.0, 5.0, 60.0]])\n", - " self.z.requires_grad = True\n", - " \n", - " self.in_vals = torch.from_numpy(transform.forward(self.z).detach().numpy())\n", - " #self.in_vals = self.z\n", - " self.in_vals.requires_grad = True \n", - "\n", - " self.out_val = model.solve_t(self.in_vals)\n", - " self.out1, self.out2, self.out3 = torch.chunk(self.out_val, chunks=3, dim=1)\n", - "\n", - " # Compute gradients using backward function for y\n", - " def back_x1(self): \n", - " self.out1.backward()\n", - " d1 = self.in_vals.grad\n", - " a, b, c = torch.chunk(d1, chunks=3, dim=1)\n", - " return [a, b, c]\n", - " \n", - " def back_x2(self): \n", - " self.out2.backward()\n", - " d2 = self.in_vals.grad\n", - " a, b, c = torch.chunk(d2, chunks=3, dim=1)\n", - " return [a, b, c]\n", - " \n", - " def back_x3(self): \n", - " self.out3.backward()\n", - " d3 = self.in_vals.grad\n", - " a, b, c = torch.chunk(d3, chunks=3, dim=1)\n", - " return [a, b, c]\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dz1dz2dz3
dx1[[tensor(0.)]][[tensor(0.3823)]][[tensor(0.0193)]]
dx2[[tensor(1.)]][[tensor(0.8828)]][[tensor(-0.0445)]]
dx3[[tensor(0.)]][[tensor(0.1766)]][[tensor(0.0089)]]
\n", - "
" - ], - "text/plain": [ - " dz1 dz2 dz3\n", - "dx1 [[tensor(0.)]] [[tensor(0.3823)]] [[tensor(0.0193)]]\n", - "dx2 [[tensor(1.)]] [[tensor(0.8828)]] [[tensor(-0.0445)]]\n", - "dx3 [[tensor(0.)]] [[tensor(0.1766)]] [[tensor(0.0089)]]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Define Phys model\n", - "model = Phys()\n", - "# Set transformation information and define transforamtion\n", - "trsf_info = [['identity',0.0,0.0,0.0,0.0],\n", - " ['identity',0,0.0,0.0,0.0],\n", - " ['identity',0,0.0,0.0,0.0]]\n", - " # ['linear',-3,3,30.0,80.0]]\n", - " \n", - "transform = Transformation(trsf_info)\n", - "\n", - "# List to store dx/dz values\n", - "dx_dz_pytorch2 = []\n", - "\n", - "# Define PytorchGrad object and calculate gradient\n", - "pyGrad2 = PytorchGrad2(model, transform)\n", - "dx_dz_pytorch2.append(pyGrad2.back_x1())\n", - "\n", - "pyGrad2 = PytorchGrad2(model, transform)\n", - "dx_dz_pytorch2.append(pyGrad2.back_x2())\n", - "\n", - "pyGrad2 = PytorchGrad2(model, transform)\n", - "dx_dz_pytorch2.append(pyGrad2.back_x3())\n", - "\n", - "# print(dx_dz_pytorch2) # check if output matches expectations\n", - "\n", - "# convert to pandas DataFrame for readability\n", - "jacob_mat_2 = pd.DataFrame(dx_dz_pytorch2, columns=['dz1', 'dz2', 'dz3'])\n", - "jacob_mat_2.index = ['dx1', 'dx2', 'dx3']\n", - "jacob_mat_2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### (2) Manual gradient calculation with Phys Model" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "### Function that manually calculates a derivative ###\n", - "def getGrad(f_eps, f, eps):\n", - " return (f_eps - f) / (eps)\n", - "\n", - "### Function that returns a list of gradients ###\n", - "def gradList(f_eps1, f_eps2, f_eps3, f, eps): \n", - " return [getGrad(f_eps1, f, eps), getGrad(f_eps2, f, eps), getGrad(f_eps3, f, eps)]" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dz1dz2dz3
dx1tensor(0.)tensor(0.5161)tensor(0.0185)
dx2tensor(1.)tensor(1.1918)tensor(-0.0491)
dx3tensor(0.)tensor(0.1766)tensor(0.0084)
\n", - "
" - ], - "text/plain": [ - " dz1 dz2 dz3\n", - "dx1 tensor(0.) tensor(0.5161) tensor(0.0185)\n", - "dx2 tensor(1.) tensor(1.1918) tensor(-0.0491)\n", - "dx3 tensor(0.) tensor(0.1766) tensor(0.0084)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# List to store dx/dz values\n", - "dx_dz = []\n", - "dx1_dz = []\n", - "dx2_dz = []\n", - "dx3_dz = []\n", - "\n", - "# Set up parameters\n", - "eps = 3.5 \n", - "z = torch.Tensor([[1.0, 5.0, 60.0]])\n", - "z_eps1 = torch.Tensor([[1.0 + eps, 5.0, 60.0]])\n", - "z_eps2 = torch.Tensor([[1.0, 5.0 + eps, 60.0]])\n", - "z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + eps]])\n", - "\n", - "x1_eps1 = model.solve_t(z_eps1)[0,0]\n", - "x1_eps2 = model.solve_t(z_eps2)[0,0]\n", - "x1_eps3 = model.solve_t(z_eps3)[0,0]\n", - "x1_eps = model.solve_t(z)[0,0]\n", - "\n", - "dx1_dz = gradList(x1_eps1, x1_eps2, x1_eps3, x1_eps, eps)\n", - "dx_dz.append(dx1_dz)\n", - "\n", - "x2_eps1 = model.solve_t(z_eps1)[0,1]\n", - "x2_eps2 = model.solve_t(z_eps2)[0,1]\n", - "x2_eps3 = model.solve_t(z_eps3)[0,1]\n", - "x2_eps = model.solve_t(z)[0,1]\n", - "\n", - "dx2_dz = gradList(x2_eps1, x2_eps2, x2_eps3, x2_eps, eps)\n", - "dx_dz.append(dx2_dz)\n", - "\n", - "x3_eps1 = model.solve_t(z_eps1)[0,2]\n", - "x3_eps2 = model.solve_t(z_eps2)[0,2]\n", - "x3_eps3 = model.solve_t(z_eps3)[0,2]\n", - "x3_eps = model.solve_t(z)[0,2]\n", - "\n", - "dx3_dz = gradList(x3_eps1, x3_eps2, x3_eps3, x3_eps, eps)\n", - "dx_dz.append(dx3_dz)\n", - "\n", - "# print(dx_dz) # check if values match expected outputs\n", - "\n", - "# convert to pandas DataFrame for readability\n", - "jacob_mat_3 = pd.DataFrame(dx_dz, columns=['dz1', 'dz2', 'dz3'])\n", - "jacob_mat_3.index = ['dx1', 'dx2', 'dx3']\n", - "jacob_mat_3" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### We focus on dx2_dz3 to check if it converges to the Pytorch gradient value\n", - "- Note: adjust values to check convergence for other gradients of interest" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "## Focus: dx2_dz3\n", - "\n", - "initial_eps = 15 # Initial change of value (eps)\n", - "k = 150 # Number of iterations\n", - "dx2_dz3_list = [] # List to store results\n", - "pytorch_grad1 = -2.5484 # Pytorch gradient value (first version)\n", - "pytorch_grad2 = -0.0445 # Pytorch gradient value (second version)\n", - "\n", - "# Calculate for dx2_dz3 as eps decreases\n", - "for t in range(1, k):\n", - " update_eps = initial_eps*(1/t) # updated eps value\n", - " z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + update_eps]]) # update z_eps3\n", - " x2_eps3 = model.solve_t(z_eps3)[0,1] # update x2_eps3\n", - " dx2_dz3_list.append(getGrad(x2_eps3, x2_eps, update_eps)) # store result to dx2_dz3_list" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEWCAYAAAC9qEq5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZwU1b3//9ebAQREI6tBEEFxCSpBmChuCSoiRoOanxsxBheucYlb4jeSqwa3JMQYzTW5iRdX4i4mRjRRBJRLxHUQxF0UEUYQEISrIvvn90fVjD0zPTPNMN09wvv5eNSjuk6dOnW6ZqY/c06drqOIwMzMrBCaFbsCZma25XDQMTOzgnHQMTOzgnHQMTOzgnHQMTOzgnHQMTOzgnHQMTOzgnHQsS2OpLmSBqWv/1PSrU2gTqdJeqYRyztO0nxJn0nap7HKzSj/Skl3N3a5dZyvh6SQ1LxQ57T8cNCxJkXSyZJekPS5pMXp63MlKR/ni4hfR8SITS0nlw/F9IN6bRoIlkt6VtL+DTjXFEn11fl64CcR0TYiZmzsORqLpM6S7pO0QNIKSdMk7ZfH850s6e30XIsljZW0bb7OZxvPQceaDEk/A/4L+B3wdWB74GzgQKBlLceUFKyCjeOBiGgLdAKeAf6ep4C6E/B6Qw5s5GvaFngJ6A+0B8YC/5TUthHPkWkacGBEfA3YGWgOXJunc1kDOOhYkyDpa8DVwLkR8VBEfBqJGRFxSkSsTvPdKekvkv4l6XPgEElHSZoh6f/SLqUrq5V9qqQPJC2VdFm1fVW6iSQNSFsgyyW9Imlgxr4pkq5J/1v/VNKTkjqmu6em6+VpS6bOFkxErCX5AP460CHL9ThA0kvpf+wvSTogTf8VcDDwp/Q8f6p23FaSPgNKgFckvZemfyOt/3JJr0samnFMjWuapT49Jf1v+r4nAh0z9p0kaU5Fi0LSkZI+ktQpIuZExA0RsTAi1kfEGJJ/IHav6/pIKpF0vaSPJc0BjsrYt3/63iuWVZLmptd1fkR8nFHUeqBXXeeyAosIL16KvgBDgHVA83ry3QmsIGn9NANaAQOBvdPtPsAi4Ng0f2/gM+DbwFbADel5BqX7rwTuTl93BZYC303LOjzd7pTunwK8B+wGtE63R6f7egBRV/2rnWsrkhbd/HT7NOCZ9HV74BPgVJL/1Iel2x0y6jGinusUQK/0dQvgXeA/ST7wDwU+BXav7ZpmKe+59NptlV7LTyveS7r/nrScDsAC4Oha6tUXWAV8rZ76nw28BeyYXo+ns13f9L1NAX6TkXZQ+n4C+BwYXOzfby9fLm7pWFPREfg4ItZVJGS0OL6Q9O2MvI9ExLSI2BARqyJiSkS8mm7PAu4DvpPmPR54LCKmRtJaugLYUEsdfgj8KyL+lZY1ESgjCUIV7oiIdyLiC+BBkg/RjXGipOXAfJIup2Oz5DkKmB0Rd0XEuoi4j+QD+Hsbea4KA0i6uUZHxJqIeAp4jCSYVahyTTMPltQd+BZwRUSsjoipwKPVznEeSTCbAjwaEY9Vr0TaEroLuCoiVtRT5xOBP0TSclkG/KaWfDeRBJbKFmxEPBNJ91o3ksA+t55zWQE56FhTsRTomHkjPiIOiIjt0n2Zv6vzMw+UtJ+kpyUtkbSC5L/kiu6fHTLzR8TnaXnZ7ASckAa65WlwOAjokpHno4zXK0k+zDfGgxGxXUR0johDI2J6ljw7AB9US/uApCXWEDuQtKgyg2318uZTux2AT9Jrl3l8pYhYDowD9gJ+X70ASa1JAtXzEVFbAKlR59rOl5b5Y5JW7g+qvbeKOn0IPAHcn8P5rEAcdKypeA5YDRyTQ97q83HcC4wHdkz/w70ZqLg5v5CkiwYASW3Icg8lNR+4Kw0KFcvWETG6AXXaFAtIAmCm7sCHDTzXAmBHSZl/75nl1VfmQqCdpK2rHV9JUl/gDJJW5k3V9m0F/CM9349zrHOVn1uW8x0MXAMcU0+rqTmwS47ntAJw0LEmIf1P+Srgz5KOl9RWUrP0w2zreg7fBlgWEask7Qv8IGPfQ8DRkg6S1JJksEJtv/d3A9+TdER6I7uVpIGSuuXwFpaQdNvtnEPe+vwL2E3SDyQ1l3QSyb2pii6rRRt5nhdIuqB+LqlFOjjie+TYAoiID0i6Ga+S1FLSQWR09UlqRXLt/hM4Hegq6dx0XwuSn8EXwI+ytUhq8SBwgaRuktoBIzPOtyPwQFreO5kHSTpFUncldgJ+BUzO8ZxWAA461mRExHXAT4GfA4tJPlz/B7gUeLaOQ88Frpb0KfBLkg+sijJfJ7nfcC/Jf8+fAOW1nH8+SUvrP0mCyHzg/5HD30lErCT5gJuWds0NqO+YOspaChwN/IykK/DnJDfmK0Zl/RdwvKRPJN1USzGZ5a0BhgJHAh8Dfyb5wH5rI6r1A2A/YBkwCvhrxr7fAOUR8Zf0vtkPgWsl7QockL6XwXw5su+ztKVSl1uACcArwMvA3zP2HUYy6u+hjPIqhof3Jvld+Yxk+PTbwH9sxPu0PFOEZw41M7PCcEvHzMwKxkHHzIpC0s3VvuRZsdxc7LpZ/rh7zczMCsZPbK1Dx44do0ePHsWuhpnZV8r06dM/johO2fY56NShR48elJWVFbsaZmZfKZJqfJm3gu/pmJlZwTjomJlZwTjomJlZwTjomJlZwTjomJlZwTjomJlZwXjIdB3efhsGDix2LczMNh8OOma2BYmMVXx11pl1rlH/2l5v4nFbbQVddqj9UjaQg04ddt8dpkwpdi3MahEB69Yly9q1yVLxOltaba+zpa1fX3VZt65mWl3LxuRvSNkbNjRs8WO/cjdgAEx5rkGHSrXvK1rQkdSeZCKmHiRzmJ8YEZ9kyTccuDzdvDYixlbbPx7YOSL2SrevJJk/Y0ma5T8j4l/pvl8AZwLrgQsiYkLjvivb7EQkH8JffAErVybrzGXVKli9GtasSdaZr6uvc02rWNcXHNatK+61KSnJvjRvXvu+uvK3aAGtWuWWt1mzL9eb4yJVfV2xnW2dj315VMyWzkhgckSMljQy3b40M0MamEYBpSRtvumSxlcEJ0nfJ5msqbobI+L6amX1Bk4G9iSZf32SpN0iYn0jvy8rtA0b4LPPkuXTT3NbZwsitW1vyHWyyzq0aJF0V7RsmawzX2emtWnzZVqLFsnSvHnurxsrb32BpJnHIFnDFDPoHAMMTF+PBaZQLegARwATI2IZgKSJwBDgPkltSWaZPIuMmSLrOd/96cyG70t6F9gXaFj70RpPRBIMli2DTz6puSxfnj39//4vOW7lytzP1bo1tG2bfLi3bp0sbdokS4cOVdMqXmcu1dNbtao9iGSu8/zfo9lXRTGDzvYRsRAgIhZK6pwlT1eSKYMrlKdpANcAvweyfeL8RNKPSOZ1/1naMuoKPF9LWdbYIpLAUF4OixYly+LFyZLt9erVtZdVUgLbbQft2n259OgBX/sabLNNsrRtW/N19bS2bZOyzKxo8hp0JE0imcu8ustyLSJLWkjqC/SKiIsl9ai2/y8kASn4MjCdUVtZWep8Fknrie7du+dYzS3QunXwwQcwZw7Mnw/z5iXrzNfZWiAtWkDnzrD99sm6d+8vX3foUDO4tGuXBAu3FMw2C3kNOhExqLZ9khZJ6pK2croAi7NkK+fLLjiAbiTdcPsD/SXNJXkPnSVNiYiBEbEo4xy3AI9llLVjtbIWZKnzGGAMQGlpqYe6LF8Os2bBG2/A7NnwzjvJes6c5IZ2pq9/Hbp3h732giOPhB13TJavfz0JKp07J0HFAcRsi1XM7rXxwHBgdLp+JEueCcCvJbVLtwcDv0jv8fwFIG3pPBYRA9PtLhXddsBxwGsZ57tX0g0kAwl2BV5s3Lf0FRYB778PL78Mr7zy5TJv3pd5WreGXr1gzz3h2GNht91gl12SQNO1a3LvwsysDsUMOqOBByWdCcwDTgCQVAqcHREjImKZpGuAl9Jjrq4YVFCH69LutyAZiv1jgIh4XdKDwBvAOuC8LXrk2rp18NJL8Mwz8OyzybI4bWyWlCRfUjrwQDj3XOjTJ2m9dO3qUUtmtkkU/rJUrUpLS2Ozmjl07lyYMAGefBImT4YVK5L0Xr3ggAOSpbQ0acm0alXUqprZV5ek6RFRmm2fn0iwuZs7F8aNgwcfhIoAuuOOcMIJcPjhycPlOmcbOGhm1vgcdDZHq1bB/ffDzTfDCy8kad/6Flx3HQwdmtyL8c18MysCB53NSXk5/OUvMGYMfPxxMhx59Gg48UTo2bPYtTMzc9DZLMybB9dcA3fckYxCGzoUzj8fDjnELRoza1IcdL7KPvsMrr0Wbrwx2T73XLj4YrdqzKzJctD5qvrnP+Hss5MutR/9KGnp+AkKZtbE+UsXXzWffw4//jEcfXTyiJhp02DsWAccM/tKcEvnq2TuXDjmGHj1Vfj5z+Hqq5OnGJuZfUU46HxVPPdcMkBg7Vp4/HE44ohi18jMbKO5e+2rYOrU5Iuc220HL77ogGNmX1lu6TR1U6cmT2zu3h2eegq6dCl2jczMGsxBpymbPTt5mnP37jBlSjLvjJnZV5i715qqTz5JRqiVlCTDox1wzGwz4JZOUxQBp52WzG8zeTLsvHOxa2Rm1igcdJqiu++G8ePh+uvh4IOLXRszs0bj7rWmZsECuOCCZG6biy4qdm3MzBqVg05Tc9FFydQEd9yR3M8xM9uMOOg0JWVlyYRrP/95MueNmdlmxkGnKbn8cmjfHn72s2LXxMwsLzyQoKmYOhUmTIDf/Q623bbYtTEzy4uitHQktZc0UdLsdN2ulnzD0zyzJQ3Psn+8pNcyth+QNDNd5kqamab3kPRFxr6b8/fuGujqq5OnDZx3XrFrYmaWN8Vq6YwEJkfEaEkj0+1LMzNIag+MAkqBAKZLGh8Rn6T7vw98lnlMRJyUcfzvgRUZu9+LiL75eDOb7N13k+/jXHsttG5d7NqYmeVNse7pHAOMTV+PBY7NkucIYGJELEsDzURgCICktsBPgWuzFS5JwInAfY1c7/y47bZkpNrppxe7JmZmeVWsoLN9RCwESNeds+TpCszP2C5P0wCuAX4PrKyl/IOBRRExOyOtp6QZkv5XUq3fuJR0lqQySWVLlizJ8e1sgrVrk+HRRx0FO+yQ//OZmRVR3rrXJE0Cvp5l12W5FpElLST1BXpFxMWSetRy7DCqtnIWAt0jYqmk/sA/JO0ZEf9X4wQRY4AxAKWlpZFjXRvuscdg0SIYMSLvpzIzK7a8BZ2IGFTbPkmLJHWJiIWSugCLs2QrBwZmbHcDpgD7A/0lzSWpf2dJUyJiYFp2c+D7QP+MuqwGVqevp0t6D9gNKGvo+2s0t92WtHCOPLLYNTEzy7tida+NBypGow0HHsmSZwIwWFK7dHTbYGBCRPwlInaIiB7AQcA7FQEnNQh4KyLKKxIkdZJUkr7eGdgVmNPI72njrVwJkybBSSdBc49eN7PNX7GCzmjgcEmzgcPTbSSVSroVICKWkdy7eSldrk7T6nMyNQcQfBuYJekV4CHg7BzLyq8pU2D1ardyzGyLoYj837b4qiotLY2ysjz2wJ1/Ptx+OyxdCq1a5e88ZmYFJGl6RJRm2+fH4BTTE0/AIYc44JjZFsNBp1jefTdZhgwpdk3MzArGQadYnngiWft+jpltQRx0iuXxx6FXL9hll2LXxMysYBx0iiECpk2Dww4rdk3MzArKQacY3n8fVqyAfv2KXRMzs4Jy0CmGmTOT9T77FLceZmYF5qBTDDNmJE+V3muvYtfEzKygHHSKYcYM2GMPz51jZlscB51imDHDXWtmtkVy0Cm0xYthwQIHHTPbIjnoFJoHEZjZFsxBp9BmzEjWffsWtx5mZkXgoFNoM2ZAjx7Qrl2xa2JmVnAOOoU2c6ZbOWa2xXLQKaT162HOnGS4tJnZFshBp5AWLYK1a2GnnYpdEzOzonDQKaT585P1jjsWtx5mZkXioFNI8+Yl6+7di1sPM7MiKVrQkdRe0kRJs9N11uFckoaneWZLGp6RPkXS25JmpkvnNH0rSQ9IelfSC5J6ZBzzizT9bUlH5Ps91uCgY2ZbuGK2dEYCkyNiV2Byul2FpPbAKGA/YF9gVLXgdEpE9E2XxWnamcAnEdELuBH4bVpWb+BkYE9gCPBnSSX5eWu1mDcPtt0Wvva1gp7WzKypKGbQOQYYm74eCxybJc8RwMSIWBYRnwATSQJGruU+BBwmSWn6/RGxOiLeB94lCWSFM3++7+eY2RatmEFn+4hYCJCuO2fJ0xWYn7FdnqZVuCPtWrsiDSxVjomIdcAKoEMOZQEg6SxJZZLKlixZ0rB3Vpt589y1ZmZbtLwGHUmTJL2WZTkm1yKypEW6PiUi9gYOTpdT6zmmrrK+TIgYExGlEVHaqVOnHKuZIwcdM9vCNc9n4RExqLZ9khZJ6hIRCyV1ARZnyVYODMzY7gZMScv+MF1/Kulekq6yv6bH7AiUS2oOfA1YlpGeWdaChr2zBvjiC1iyxEHHzLZoxexeGw9UjEYbDjySJc8EYLCkdukAgsHABEnNJXUEkNQCOBp4LUu5xwNPRUSk6Seno9t6ArsCL+bhfWXn7+iYmeW3pVOP0cCDks4E5gEnAEgqBc6OiBERsUzSNcBL6TFXp2lbkwSfFkAJMAm4Jc1zG3CXpHdJWjgnA0TE65IeBN4A1gHnRcT6grxT+DLouKVjZlswJY0Ay6a0tDTKysoap7A77oAzzkievdazZ+OUaWbWBEmaHhGl2fb5iQSFMm8eSNC1xoA5M7MthoNOocybB126QMuWxa6JmVnROOgUir8YambmoFMw/o6OmZmDTsHMnw/duhW7FmZmReWgUwirV8PKldCxY7FrYmZWVA46hbBiRbL206XNbAvnoFMIFUFnu+2KWw8zsyJz0CmE5cuTtVs6ZraFc9ApBHevmZkBDjqF4e41MzPAQacw3L1mZgY46BSGu9fMzIAcg46kE3JJs1osX5487HObbYpdEzOzosq1pfOLHNMsmxUrYNttoZkblma2ZatzEjdJRwLfBbpKuilj17YkE6FZLlas8CACMzPqnzl0AVAGDAWmZ6R/Clycr0ptdpYv9/0cMzPqCToR8QrwiqR7I2Jtgeq0+VmxwkHHzIzc7+nsK2mipHckzZH0vqQ5ea3Z5sTda2ZmQO5B5zbgBuAg4FtAabpuEEnt0yA2O123qyXf8DTPbEnDM9KnSHpb0sx06Zym/1TSG5JmSZosaaeMY9Zn5B/f0Lo3iLvXzMyA+u/pVFgREY834nlHApMjYrSkken2pZkZJLUHRpEEuACmSxofEZ+kWU6JiLJq5c4ASiNipaRzgOuAk9J9X0RE30Z8D7lzS8fMDMi9pfO0pN9J2l9Sv4plE857DDA2fT0WODZLniOAiRGxLA00E4EhdRUaEU9HxMp083mg+LOmRfiejplZKteWzn7pujQjLYBDG3je7SNiIUBELKzoHqumKzA/Y7s8Tatwh6T1wN+AayMiqh1/JpDZOmslqYxkqPfoiPhHtopJOgs4C6B7Y0wv/fnnsH69g46ZGTkGnYg4ZGMLljQJ+HqWXZflWkS2qqTrUyLiQ0nbkASdU4G/Zpz7hyQB8jsZx3aPiAWSdgaekvRqRLxX4wQRY4AxAKWlpdUD2cbzwz7NzCrl+hic7SXdJunxdLu3pDPrOiYiBkXEXlmWR4BFkrqkZXUBFmcpohzYMWO7G8n3hoiID9P1p8C9wL4ZdR1EEtiGRsTqjPpUHDsHmALsk8t732R+2KeZWaVc7+ncCUwAdki33wEu2oTzjgcqRqMNBx7JkmcCMFhSu3R022BggqTmkjoCSGoBHA28lm7vA/wPScCpDGRpGVulrzsCBwJvbEL9c+eWjplZpVyDTseIeBDYABAR64D1m3De0cDhkmYDh6fbSCqVdGt6jmXANcBL6XJ1mrYVSfCZBcwEPgRuScv9HdAWGFdtaPQ3gDJJrwBPk9zTKUzQcUvHzKxSrgMJPpfUgfSeiqQBwIqGnjQilgKHZUkvA0ZkbN8O3F4tz+dA/1rKHVRL+rPA3g2t7ybxtAZmZpVyDTo/JekS20XSNKATcHzearU5cfeamVmlXEevvSzpO8DuJKPK3vaz2HLk7jUzs0r1TW1waEQ8Jen71XbtJomI+Hse67Z5WLECWraEVq2KXRMzs6Krr6XzHeAp4HtZ9gXgoFOfiueuKdvXjszMtiz1TW0wKl2fXpjqbIb8CBwzs0r1da/9tK79EXFD41ZnM+SHfZqZVaqve22bdL07yVQGFd97+R4wNV+V2qx4WgMzs0r1da9dBSDpSaBf+tgZJF0JjMt77TYHK1ZAly7FroWZWZOQ6xMJugNrMrbXAD0avTabI7d0zMwq5frl0LuAFyU9TDJq7TgynupsdfBAAjOzSrl+OfRXkp4gma4a4PSImJG/am0m1q+Hzz5z0DEzS+Xa0iEipkuaD7QCkNQ9IublrWabg1WrknWbNsWth5lZE5HrfDpD0ydCvw/8b7p+vO6jjLXpk4JatixuPczMmohcBxJcAwwA3omInsAgYFrearW5WJOOvWjRorj1MDNrInINOmvT6QiaSWoWEU8DffNYr81DRdBxS8fMDMj9ns5ySW1JvhB6j6TFwLr8VWsz4aBjZlZFri2dY4CVwMXAE8B7ZH8IqGXyPR0zsyrqbelIKgEeSWfl3ACMzXutNhe+p2NmVkW9LZ2IWA+slOQvm2wsd6+ZmVWRa/faKuBVSbdJuqliaehJJbWXNFHS7HTdrpZ8w9M8syUNz0ifIultSTPTpXOafpqkJRnpI+orK6/cvWZmVkWuAwn+mS6QPAYHkmmrG2okMDkiRksamW5fmplBUntgFFCannO6pPER8Uma5ZSIKMtS9gMR8ZONLCs/3L1mZlZFffPpHAN0i4j/TrdfBDqRfHBfWtex9TgGGJi+HgtMyVLeEcDEiFiWnnsiMAS4rwHna8yycufuNTOzKurrXvs5X86hA9AS6E8SMM7ehPNuHxELAdJ15yx5ugLzM7bL07QKd6RdaFdIVeaC/v8kzZL0kKQdcyyrkqSzJJVJKluyZMlGvq1qHHTMzKqoL+i0jIjMD+tnImJZ+sy1res6UNIkSa9lWY7JsW7Zuu8quvZOiYi9gYPT5dQ0/VGgR0T0ASbx5Ui7usqqmhgxJiJKI6K0U6dOOVa1Fr6nY2ZWRX1Bp8oN/mr3Sur8RI6IQRGxV5blEWCRpC4A6XpxliLKgR0ztrsBC9KyP0zXnwL3Avum20sjYnWa/xaSVlmdZeWV7+mYmVVRX9B5QdJ/VE+U9GPgxU0473igYgTZcOCRLHkmAIMltUtHtw0GJkhqLqljWo8WwNHAa+l25hSdQ4E36yprE+qfG3evmZlVUd/otYuBf0j6AfBymtYf2Ao4dhPOOxp4UNKZwDzgBABJpcDZETEiIpZJugZ4KT3m6jRta5Lg0wIoIelGuyXNc4GkoSSP6FkGnAZQW1mbUP/cuHvNzKwKRWS9tVE1k3QosGe6+XpEPJXXWjURpaWlUVaWbVR2jsaMgR//GD78EHbYofEqZmbWhEmaHhGl2fblOnPoU8AWEWgalbvXzMyqyPWJBNYQDjpmZlU46OST7+mYmVXhoJNPHjJtZlaFg04+rVkDzZpBSUmxa2Jm1iQ46OTTmjXuWjMzy+Cgk09r17przcwsg4NOPrmlY2ZWhYNOPjnomJlV4aCTT+5eMzOrwkEnn9zSMTOrwkEnnxx0zMyqcNDJJwcdM7MqHHTyyfd0zMyqcNDJJ7d0zMyqcNDJJwcdM7MqHHTyyd1rZmZVOOjkk1s6ZmZVOOjkk4OOmVkVRQk6ktpLmihpdrpuV0u+4Wme2ZKGZ6RPkfS2pJnp0jlNvzEj7R1JyzOOWZ+xb3z+3yUOOmZm1TQv0nlHApMjYrSkken2pZkZJLUHRgGlQADTJY2PiE/SLKdERFnmMRFxccbx5wP7ZOz+IiL6Nv5bqYPv6ZiZVVGs7rVjgLHp67HAsVnyHAFMjIhlaaCZCAzZiHMMA+7bpFpuKrd0zMyqKFbQ2T4iFgKk685Z8nQF5mdsl6dpFe5Iu8qukKTMAyXtBPQEnspIbiWpTNLzkrIFuYpjz0rzlS1ZsmQj31Y1DjpmZlXkrXtN0iTg61l2XZZrEVnSIl2fEhEfStoG+BtwKvDXjHwnAw9FxPqMtO4RsUDSzsBTkl6NiPdqnCBiDDAGoLS0NKrv3yjuXjMzqyJvQSciBtW2T9IiSV0iYqGkLsDiLNnKgYEZ292AKWnZH6brTyXdC+xLzaBzXrX6LEjXcyRNIbnfUyPoNCq3dMzMqihW99p4oGI02nDgkSx5JgCDJbVLR7cNBiZIai6pI4CkFsDRwGsVB0naHWgHPJeR1k7SVunrjsCBwBuN/q6qc9AxM6uiWKPXRgMPSjoTmAecACCpFDg7IkZExDJJ1wAvpcdcnaZtTRJ8WgAlwCTgloyyhwH3R0Rm19g3gP+RtIEk0I6OiPwGnQhYt85Bx8wsQ1GCTkQsBQ7Lkl4GjMjYvh24vVqez4H+dZR9ZZa0Z4G9G17jBli7Nln7no6ZWSU/kSBf1qxJ1m7pmJlVctDJFwcdM7MaHHTyxd1rZmY1OOjki1s6ZmY1OOjki4OOmVkNDjr54qBjZlaDg06++J6OmVkNDjr54paOmVkNDjr54qBjZlaDg06+uHvNzKwGB518cUvHzKwGB518cdAxM6vBQSdfKoKOu9fMzCo56ORLxT0dt3TMzCo56OSLu9fMzGpw0MkXBx0zsxocdPLF93TMzGpw0MkX39MxM6vBQSdf3L1mZlZD82KdWFJ74AGgBzAXODEiPsmSbzhwebp5bUSMTdNbAn8CBgIbgMsi4m+StgL+CvQHlgInRcTc9JhfAGcC64ELImJCnt6eu9dss7d27VrKy8tZtWpVsatiRdKqVSu6detGi434nCta0AFGApMjYrSkken2pZkZ0sA0CigFApguaXwanC4DFkfEbpKaAe3Tw84EPomIXpJOBn4LnCSpN3AysCewAzBJ0m4RsT4v786PwbHNXHl5Odtssw09evRAUrGrYwUWESxdupTy8vWtlUcAABdiSURBVHJ69uyZ83HF7F47Bhibvh4LHJslzxHAxIhYlgaaicCQdN8ZwG8AImJDRHycpdyHgMOU/EUcA9wfEasj4n3gXWDfRn5PX1qzBkpKksVsM7Rq1So6dOjggLOFkkSHDh02uqVbzKCzfUQsBEjXnbPk6QrMz9guB7pK2i7dvkbSy5LGSdq++jERsQ5YAXSorazqJ5R0lqQySWVLlixp+Ltbs8b3c2yz54CzZWvIzz+vQUfSJEmvZVmOybWILGlB0i3YDZgWEf2A54Dr6zmmtvSqCRFjIqI0Iko7deqUYzWzWLPGXWtmZtXkNehExKCI2CvL8giwSFIXgHS9OEsR5cCOGdvdgAUkAwRWAg+n6eOAftWPkdQc+BqwrI6y8mPtWrd0zPJMEqeeemrl9rp16+jUqRNHH330RpXTo0cPPv744wbl+eyzzzjnnHPYZZdd2Geffejfvz+33HLLRp2/ujvvvJOf/OQnANx888389a9/bVA5c+fO5d57792kujS2YnavjQeGp6+HA49kyTMBGCypnaR2wGBgQkQE8CjJyDWAw4A3spR7PPBUmn88cLKkrST1BHYFXmzct5TB3Wtmebf11lvz2muv8cUXXwAwceJEunat0WueVyNGjKBdu3bMnj2bGTNm8MQTT7Bs2bIa+davb9iYpbPPPpsf/ehHDTq2KQadYo5eGw08KOlMYB5wAoCkUuDsiBgREcskXQO8lB5zdURU/DQvBe6S9AdgCXB6mn5bmv4uSQvnZICIeF3SgyTBaR1wXt5GroG712zLctFFMHNm45bZty/84Q/1ZjvyyCP55z//yfHHH899993HsGHD+Pe//w3AsmXLOOOMM5gzZw5t2rRhzJgx9OnTh6VLlzJs2DCWLFnCvvvuS/J/aeLuu+/mpptuYs2aNey33378+c9/pqSWAUHvvfceL774Ivfeey/NmiX/w3fq1IlLL00G4k6ZMoWrrrqKLl26MHPmTN544w2OPfZY5s+fz6pVq7jwwgs566yzALjjjjv4zW9+Q5cuXdhtt93YaqutALjyyitp27Ytl1xyCe+99x7nnXceS5YsoU2bNtxyyy3ssccenHbaaWy77baUlZXx0Ucfcd1113H88cczcuRI3nzzTfr27cvw4cO5+OKLG/7zaCRFa+lExNKIOCwidk3Xy9L0sogYkZHv9ojolS53ZKR/EBHfjog+6fHz0vRVEXFCmn/fiJiTccyvImKXiNg9Ih7P6xt095pZQZx88sncf//9rFq1ilmzZrHffvtV7hs1ahT77LMPs2bN4te//nVli+Gqq67ioIMOYsaMGQwdOpR58+YB8Oabb/LAAw8wbdo0Zs6cSUlJCffcc0+t53799df55je/WRlwsnnxxRf51a9+xRtvJJ0xt99+O9OnT6esrIybbrqJpUuXsnDhQkaNGsW0adOYOHFiZd7qzjrrLP74xz8yffp0rr/+es4999zKfQsXLuSZZ57hscceY+TIkQCMHj2agw8+mJkzZzaJgAPFbels3ty9ZluSHFok+dKnTx/mzp3Lfffdx3e/+90q+5555hn+9re/AXDooYeydOlSVqxYwdSpU/n73/8OwFFHHUW7du0AmDx5MtOnT+db3/oWAF988QWdO2cbWJvdr371K8aNG8fixYtZsCC5ZbzvvvtW+R7LTTfdxMMPJ7ej58+fz+zZs/noo48YOHAgFYOXTjrpJN55550qZX/22Wc8++yznHDCCZVpq1evrnx97LHH0qxZM3r37s2iRYtyrnOhOejki4OOWcEMHTqUSy65hClTprB06dLK9MxuswoVw3yzDfeNCIYPH85vfvObnM7bu3dvXnnlFTZs2ECzZs247LLLuOyyy2jbtm1lnq233rry9ZQpU5g0aRLPPfccbdq0YeDAgZXfc6lv+PGGDRvYbrvtmFlLN2ZFd1zF+2iq/Oy1fPE9HbOCOeOMM/jlL3/J3nvvXSX929/+dmX32JQpU+jYsSPbbrttlfTHH3+cTz5JnsB12GGH8dBDD7F4cTKYdtmyZXzwwQe1nrdXr16UlpZy+eWXVw4UWLVqVa0f+itWrKBdu3a0adOGt956i+effx6A/fbbrzJgrl27lnHjxtU4dtttt6Vnz56V+yKCV155pc7rss022/Dpp5/WmafQHHTyxfd0zAqmW7duXHjhhTXSr7zySsrKyujTpw8jR45k7NjkYSWjRo1i6tSp9OvXjyeffJLu3bsDScvl2muvZfDgwfTp04fDDz+chQsX1nnuW2+9laVLl9KrVy/69+/PoEGD+O1vf5s175AhQ1i3bh19+vThiiuuYMCAAQB06dKFK6+8kv33359BgwbRr1+/rMffc8893HbbbXzzm99kzz335JFHsg36/VKfPn1o3rw53/zmN7nxxhvrzFsoasrNsGIrLS2NsrKyhh188MFJ0Jk8uXErZdZEvPnmm3zjG98odjWsyLL9HkiaHhGl2fK7pZMv7l4zM6vBQSdf3L1mZlaDg06+ePSamVkNDjr54qBjZlaDg06++J6OmVkNDjr54ns6ZmY1OOjki7vXzPKupKSEvn37stdee3HCCSewcuXKWvM29hOXBw4cSC5fqZg9ezZHH300u+yyC/379+eQQw5h6tSpm3Tu0047jYceeghInnJd27Pa6jNlyhSeffbZrPtWrlzJUUcdxR577MGee+5Z+Ty3TeWgky/uXjPLu9atWzNz5kxee+01WrZsyc0331xr3oYEnXXr1m1S/VatWsVRRx3FWWedxXvvvcf06dP54x//yJw5c2rkbei5br31Vnr37t2gY+sKOgCXXHIJb731FjNmzGDatGk8/vimPyfZz17LF3ev2RakiDMbVDr44IOZNWsWV1xxBR07dqx8QsFll13G9ttvz7333lvlMf/nnHMO55xzDmVlZTRv3pwbbriBQw45hDvvvJN//vOfrFq1is8//5ynnnqK6667jrvuuotmzZpx5JFHMnr0aADGjRvHueeey/Lly7nttts4+OCDq9TpnnvuYf/992fo0KGVaXvttRd77bUXkDwxYcGCBcydO5eOHTvy61//mlNPPZXPP/8cgD/96U8ccMABRATnn38+Tz31FD179qzymJ2BAwdy/fXXU1paypNPPsmoUaNYvXo1u+yyC3fccQdt27alR48eDB8+nEcffbTyMTutWrXi5ptvpqSkhLvvvps//vGPVerfpk0bDjnkEABatmxJv379KC8v34ifYHYOOvni7jWzglm3bh2PP/44Q4YM4cgjj+T73/8+F154IRs2bOD+++/nxRdfpE+fPlx//fU89thjAPz+978H4NVXX+Wtt95i8ODBlU92fu6555g1axbt27fn8ccf5x//+AcvvPACbdq0qTJB27p163jxxRf517/+xVVXXcWkSZOq1Ov111+v9ZE2FaZPn84zzzxD69atWblyJRMnTqRVq1bMnj2bYcOGUVZWxsMPP8zbb7/Nq6++yqJFi+jduzdnnHFGlXI+/vhjrr32WiZNmsTWW2/Nb3/7W2644QZ++ctfAtCxY0defvll/vznP3P99ddz6623cvbZZ1fO1VOX5cuX8+ijj2Z91NDGctDJh/Xrk8Xda7aFKNbMBl988QV9+/YFkpbOmWeeScuWLenQoQMzZsxg0aJF7LPPPnTo0KHGsc888wznn38+AHvssQc77bRTZdA5/PDDad++PQCTJk3i9NNPp02bNgCV6QDf//73Aejfvz9z586tt77HHXccs2fPZrfddqucWmHo0KG0bt0agLVr1/KTn/ykci6fivpMnTqVYcOGUVJSwg477MChhx5ao+znn3+eN954gwMPPBCANWvWsP/++2eta8W5c7Fu3TqGDRvGBRdcwM4775zzcbVx0MmHtWuTtVs6ZnlVcU+nuhEjRnDnnXfy0Ucf1WgRVKjruZOZ0xFERK3TDlRMJ1BSUpL1nsyee+5ZZdDAww8/TFlZWZWWRea5brzxRrbffvvK6RJatWpVua++qQ8igsMPP5z77ruvQXVdv349/fv3B5JAePXVVwPJxHG77rorF110UZ3nz5UHEuSDg45ZUR133HE88cQTvPTSSxxxxBFAzcf8Z05v8M477zBv3jx23333GmUNHjyY22+/vXJkXGb3Wn1+8IMfMG3aNMaPH1+ZVtcIuxUrVtClSxeaNWvGXXfdVTldwre//W3uv/9+1q9fz8KFC3n66adrHDtgwACmTZvGu+++W3me6hPBVZd5TUpKSpg5cyYzZ86sDDiXX345K1as4A+N2JR10MmHNWuStYOOWVG0bNmSQw45hBNPPJGSkhKg5mP+zz33XNavX8/ee+/NSSedxJ133lllIrQKQ4YMYejQoZSWltK3b1+uv/76nOvRunVrHnvsMW6++WZ23nln9t9/f6699louv/zyrPnPPfdcxo4dy4ABA3jnnXcqW0HHHXccu+66K3vvvTfnnHMO3/nOd2oc26lTJ+68806GDRtGnz59GDBgAG+99Vad9fve977Hww8/TN++ffn3v/9dZV95eXnlNNv9+vWjb9++3HrrrTm/99oUZWoDSe2BB4AewFzgxIj4JEu+4UDFT+faiBibprcE/gQMBDYAl0XE3yT9FBgBrAOWAGdExAfpMeuBV9Oy5kXEl8NJatHgqQ2WL4cf/xjOOAPS/7LMNjdNeWqDDRs20K9fP8aNG8euu+5a7Ops1r4qUxuMBCZHxK7A5HS7ijQwjQL2A/YFRklql+6+DFgcEbsBvYH/TdNnAKUR0Qd4CLguo8gvIqJvutQbcDbJdtvBAw844JgVwRtvvEGvXr047LDDHHCaoGINJDiGpJUCMBaYAlxaLc8RwMSIWAYgaSIwBLgPOAPYAyAiNgAfp68zOzqfB36Yl9qbWZPVu3fvrF++tKahWC2d7SNiIUC67pwlT1dgfsZ2OdBV0nbp9jWSXpY0TtL2WY4/E8j8+mwrSWWSnpd0bG0Vk3RWmq9syZIlG/WmzLY0nnl4y9aQn3/ego6kSZJey7Ick2sRWdKCpHXWDZgWEf2A54Aqd/Yk/RAoBX6Xkdw97WP8AfAHSbtkO2lEjImI0ogo7dSpU45VNdvytGrViqVLlzrwbKEigqVLl1YZ1p2LvHWvRcSg2vZJWiSpS0QslNQFWJwlWzlfdsFBEmimAEuBlcDDafo4klZNRdmDSO75fCciVmfUZ0G6niNpCrAP8N5GvzEzA6Bbt26Ul5fjHoEtV6tWrejWrdtGHVOsezrjgeHA6HT9SJY8E4BfZwweGAz8IiJC0qMkAekp4DDgDQBJ+wD/AwyJiMpAlpaxMiJWS+oIHEjVQQZmtpFatGhBz549i10N+4opVtAZDTwo6UxgHnACgKRS4OyIGBERyyRdA7yUHnN1xaACkkEHd0n6A8nQ6NPT9N8BbYFx6bd3K4ZGfwP4H0kbSLoUR0dEw54FbmZmDVaU7+l8VTT4ezpmZluwpvg9HTMz2wK5pVMHSUuADxpwaEfS7w41Ya5j43AdG4fr2HiaQj13ioisw38ddPJAUlltTcumwnVsHK5j43AdG09Tr6e718zMrGAcdMzMrGAcdPJjTLErkAPXsXG4jo3DdWw8TbqevqdjZmYF45aOmZkVjIOOmZkVjINOI5I0RNLbkt6VVGNiumKQtKOkpyW9Kel1SRem6e0lTZQ0O123q6+sAtS1RNIMSY+l2z0lvZDW8YF0xthi1m87SQ9Jeiu9nvs30et4cfqzfk3SfZJaFftaSrpd0mJJr2WkZb12StyU/h3NktSviHX8XfrzniXp4YypVZD0i7SOb0sqyIyN2eqYse8SSZE+X7Jo17E+DjqNRFIJ8N/AkSSzmQ6T1Lu4tQKSqbt/FhHfAAYA56X1qnf21iK4EHgzY/u3wI1pHT8h42niRfJfwBMRsQfwTZK6NqnrKKkrcAHJDLp7ASXAyRT/Wt5JMgljptqu3ZHArulyFvCXItZxIrBXOhvxO8AvANK/oZOBPdNj/px+BhSjjkjaETic5FmWFYp1HevkoNN49gXejYg5EbEGuJ9khtSiioiFEfFy+vpTkg/KriR1G5tmGwvUOrFdIUjqBhwF3JpuCziUZNpxKHIdJW0LfBu4DSAi1kTEcprYdUw1B1pLag60ARZS5GsZEVOBZdWSa7t2xwB/jcTzwHbpFCgFr2NEPBkR69LN50mmWKmo4/0RsToi3gfeJfkMKHgdUzcCPyeZc6xCUa5jfRx0Gk/WmU6LVJesJPUgmUfoBXKbvbWQ/kDyR7Mh3e4ALM/4gy/29dyZ5Inmd6RdgLdK2pomdh0j4kOSSQ3nkQSbFcB0mta1rFDbtWuqf0tn8OVsxE2mjpKGAh9GxCvVdjWZOmZy0Gk8tc102iRIagv8DbgoIv6v2PXJJOloYHFETM9MzpK1mNezOdAP+EtE7AN8TtPokqwivS9yDNAT2AHYmqSbpbom87uZRVP72SPpMpKu6nsqkrJkK3gdJbUhmbTyl9l2Z0kr+s/dQafxlAM7Zmx3AxYUqS5VSGpBEnDuiYi/p8mLKpraqn321kI5EBgqaS5Jt+ShJC2f7dIuIij+9SwHyiPihXT7IZIg1JSuI8Ag4P2IWBIRa4G/AwfQtK5lhdquXZP6W5I0HDgaOCW+/GJjU6njLiT/YLyS/v10A16W9HWaTh2rcNBpPC8Bu6ajhFqS3GQcX+Q6VdwbuQ14MyJuyNhVMXsr1D57a0FExC8ioltE9CC5bk9FxCnA08DxabZi1/EjYL6k3dOkihlrm8x1TM0DBkhqk/7sK+rZZK5lhtqu3XjgR+noqwHAiopuuEKTNIRk0sihEbEyY9d44GRJW0nqSXKz/sVC1y8iXo2IzhHRI/37KQf6pb+vTeY6VhERXhppAb5LMsLlPeCyYtcnrdNBJE3qWcDMdPkuyT2TycDsdN2+2HVN6zsQeCx9vTPJH/K7wDhgqyLXrS9Qll7LfwDtmuJ1BK4C3gJeA+4Ctir2tQTuI7nHtJbkg/HM2q4dSbfQf6d/R6+SjMQrVh3fJbkvUvG3c3NG/svSOr4NHFmsOlbbPxfoWMzrWN/ix+CYmVnBuHvNzMwKxkHHzMwKxkHHzMwKxkHHzMwKxkHHzMwKxkHHrJFI6pHt6b/V8twp6fj09UXpN8ob6/zHZj5kVtLVkgY1VvlmjcFBx6x4LiJ5IGfO6nmS8bEkTzgHICJ+GRGTGlg3s7xw0DHLA0k7pw8G/VYt+y8geTba05KeTtMGS3pO0suSxqXPy0PSXEm/lPQMcIKk/5D0kqRXJP0tffrAAcBQ4HeSZkrapVqr6rC0Pq+mc7JslVH2Vek5X5W0R5r+nbScmelx2+T9otkWwUHHrJGlj8r5G3B6RLyULU9E3ETyHKxDIuKQdOKty4FBEdGP5MkHP804ZFVEHBQR9wN/j4hvRUTFnD5nRsSzJI89+X8R0Tci3suoTyuSeVhOioi9SR5eek5G2R+n5/wLcEmadglwXkT0BQ4GvtiUa2JWwUHHrHF1InmG2A8jYuZGHDeApGtsmqSZJM8i2ylj/wMZr/eS9G9JrwKnkEwkVpfdSR4C+k66PZZkbqAKFQ+BnQ70SF9PA25IW2TbxZfTIphtEgcds8a1guRZXQcCSLoj7aL6Vz3HCZiYtlL6RkTviMic3fPzjNd3Aj9JWy1XAa1yKLsuq9P1epJWEBExGhgBtAaer+h2M9tUzevPYmYbYQ3JDf0Jkj6LiNPryPspsA3wMcmslP8tqVdEvJuOauuW0TrJtA2wMJ2y4hTgw2rlVfcW0KOibOBU4H/rehOSdomIV4FXJe0P7JGWY7ZJ3NIxa2QR8TnJ/CsXS6pryvIxwOOSno6IJcBpwH2SZpEEodpaF1eQzP46kaqB4H7g/6U3/nfJqM8q4HRgXNoltwG4uZ63cZGk1yS9QnI/5/F68pvlxE+ZNjOzgnFLx8zMCsZBx8zMCsZBx8zMCsZBx8zMCsZBx8zMCsZBx8zMCsZBx8zMCub/B0XkyPGhA552AAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "## Plot result to see convergence\n", - "\n", - "fig, ax = plt.subplots()\n", - "ax.plot(range(1,k), dx2_dz3_list, c = \"red\", linestyle = \"solid\", label = \"Model Gradient\")\n", - "\n", - "plt.axhline(y = pytorch_grad2, color = 'blue', linestyle = '-', label = \"Pytorch Gradient-2\")\n", - "plt.legend(loc=\"lower right\")\n", - "plt.title(\"Gradient Plot for dx2_dz3\")\n", - "plt.ylabel(\"Gradient\")\n", - "plt.xlabel(\"k-Iterations\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dz1dz2dz3dz1dz2dz3dz1dz2dz3
dx1Nonetensor(0.3823)tensor(1.1035)[[tensor(0.)]][[tensor(0.3823)]][[tensor(0.0193)]]tensor(0.)tensor(0.5161)tensor(0.0185)
dx2tensor(1.)tensor(0.8828)tensor(-2.5484)[[tensor(1.)]][[tensor(0.8828)]][[tensor(-0.0445)]]tensor(1.)tensor(1.1918)tensor(-0.0491)
dx3Nonetensor(0.1766)tensor(0.5097)[[tensor(0.)]][[tensor(0.1766)]][[tensor(0.0089)]]tensor(0.)tensor(0.1766)tensor(0.0084)
\n", - "
" - ], - "text/plain": [ - " dz1 dz2 dz3 dz1 \\\n", - "dx1 None tensor(0.3823) tensor(1.1035) [[tensor(0.)]] \n", - "dx2 tensor(1.) tensor(0.8828) tensor(-2.5484) [[tensor(1.)]] \n", - "dx3 None tensor(0.1766) tensor(0.5097) [[tensor(0.)]] \n", - "\n", - " dz2 dz3 dz1 dz2 \\\n", - "dx1 [[tensor(0.3823)]] [[tensor(0.0193)]] tensor(0.) tensor(0.5161) \n", - "dx2 [[tensor(0.8828)]] [[tensor(-0.0445)]] tensor(1.) tensor(1.1918) \n", - "dx3 [[tensor(0.1766)]] [[tensor(0.0089)]] tensor(0.) tensor(0.1766) \n", - "\n", - " dz3 \n", - "dx1 tensor(0.0185) \n", - "dx2 tensor(-0.0491) \n", - "dx3 tensor(0.0084) " - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Combine results\n", - "result = pd.concat([jacob_mat_1, jacob_mat_2, jacob_mat_3], axis=1, join='inner')\n", - "result" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Since the gradient values match, we continue applying NOFAS to our Phys model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Jubilee's Note to Dr. Schiavazzi: Gradient calculation using pytorch (method 1) have different outputs.\n", - "- As you can see below from the plot and from the **result** data above, gradient calculation using pytorch1 and pytorch2 have different values for the third input parameter (z3) which I'm not sure. \n", - "- I'm most positive that we should get rid of the first method and stick to pytorch2 and manual gradient computation but I wanted you to double confirm on this thought." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEWCAYAAAB42tAoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de3wU5b3H8c+PcL9YEVDBiFxEbUAMkCJ4BUULiii+tIrWoujhCFrFHo/SogJWq7ZULbaW4g2LCoqVesUKKqXgBYMEkDtalBQQiMJRMELgd/6YSdyEzWQJ2exKvu/Xa1+788wzz/x2NtnfzjMzz5i7IyIiUp5aqQ5ARETSmxKFiIhEUqIQEZFIShQiIhJJiUJERCIpUYiISCQlChERiaREId8LZrbWzPqEr39lZo+mQUxXmtncKmxvoJmtM7OvzaxLVbUb0/4YM3uqqtuNWF8bM3Mzq11d65TkUKKQ/WZml5rZ+2a23cw2ha+Hm5klY33u/ht3v2Z/20nkiyz8ct0VfnlvNbN3zKxnJdY128wqinkccL27N3b3hfu6jqpiZoea2RQzW29m28xsnpmdmMT1XWpmK8N1bTKzJ83soGStT/adEoXsFzP7H+APwO+Aw4HDgGuBk4G65SyTUW0BVo1n3b0x0AKYC7yQpCR4FLC0MgtW8TZtDHwAdAMOAZ4EXjWzxlW4jljzgJPd/QdAO6A2cFeS1iWVoEQhlWZmPwDuBIa7+/Pu/pUHFrr75e7+bVhvkpn92cxeM7PtQG8zO9fMFprZ/4XdLWPKtH2FmX1qZgVmNqrMvFJdKGbWI/ylv9XMFplZr5h5s83s1+Gv4q/M7A0zax7OnhM+bw33GCL3FNx9F8GX5uFAszjb4yQz+yD8ZfyBmZ0Ult8NnAr8MVzPH8ssV8/MvgYygEVm9nFY/sMw/q1mttTMBsQss9c2jRNPWzP7Z/i+ZwLNY+ZdYmafFP9yN7N+ZrbRzFq4+yfufr+7b3D33e4+kSDpHxu1fcwsw8zGmdkWM/sEODdmXs/wvRc/Cs1sbbhd17n7lpimdgNHR61Lqpm766FHpR5AX6AIqF1BvUnANoK9jFpAfaAXcHw43Rn4HLggrJ8FfA2cBtQD7g/X0yecPwZ4Knx9BFAAnBO2dVY43SKcPxv4GDgGaBBO3xvOawN4VPxl1lWPYM9pXTh9JTA3fH0I8CVwBcEv4kHhdLOYOK6pYDs5cHT4ug6wBvgVwZf0GcBXwLHlbdM47b0bbrt64bb8qvi9hPOfDttpBqwH+pcTVzZQCPyggvivBVYAR4bb4+142zd8b7OBe2LKTgnfjwPbgbNT/fetx3cP7VHI/mgObHH3ouKCmF/235jZaTF1X3T3ee6+x90L3X22uy8JpxcDU4DTw7oXAa+4+xwP9kpuB/aUE8NPgdfc/bWwrZlALkHiKPaEu69y92+A5wi++PbFT8xsK7COoDvmgjh1zgVWu/tkdy9y9ykEX5rn7eO6ivUg6AK61913uvtbwCsECahYqW0au7CZtQZ+BNzu7t+6+xzg5TLruI4gAc0GXnb3V8oGEe5xTAbGuvu2CmL+CfCgB3sIXwD3lFNvPEEyKNlTdPe5HnQ9ZRIk47UVrEuqkRKF7I8CoHnswWB3P8ndDw7nxf59rYtd0MxONLO3zWyzmW0j+DVa3DXSKra+u28P24vnKODiMDltDb/QTwFaxtTZGPN6B8EX8L54zt0PdvdD3f0Md18Qp04r4NMyZZ8S7PFURiuCPZfYBFm2vXWUrxXwZbjtYpcv4e5bgWlAJ+D3ZRswswYEyeU9dy/vS3+vmMtbX9jmfxPsTV5W5r0Vx/Qf4HVgagLrk2qiRCH7413gW+D8BOqWHc/+GeAl4Mjwl+QEoPgA8QaC7gsAzKwhcY4JhNYBk8Mv8uJHI3e/txIx7Y/1BEkrVmvgP5Vc13rgSDOL/R+Nba+iNjcATc2sUZnlS5hZNjCEYG9ufJl59YC/h+v77wRjLvW5xVnfqcCvgfMr2DupDbRPcJ1SDZQopNLCX6RjgYfN7CIza2xmtcIvoEYVLN4E+MLdC82sO3BZzLzngf5mdoqZ1SU4YF7e3+pTwHlm9uPwYGp9M+tlZpkJvIXNBF1a7RKoW5HXgGPM7DIzq21mlxAcaynuzvl8H9fzPkH3zC1mVic8QH8eCf7SdvdPCbrgxppZXTM7hZhuMDOrT7DtfgVcBRxhZsPDeXUIPoNvgJ/F++VfjueAG8ws08yaAiNj1nck8GzY3qrYhczscjNrbYGjgLuBNxNcp1QDJQrZL+7+W+AXwC3AJoIvxL8AtwLvRCw6HLjTzL4C7iD4kilucylB//kzBL9SvwTyy1n/OoI9ml8RfPGvA/6XBP623X0HwZfSvLDbqkdFy0S0VQD0B/6HoJvsFoKDw8Vn8/wBuMjMvjSz8eU0E9veTmAA0A/YAjxM8CW7Yh/Cugw4EfgCGA38NWbePUC+u/85PA70U+AuM+sAnBS+l7P57oywr8M9giiPAP8AFgEfAi/EzDuT4Gyx52PaKz4VOIvgb+VrglNlVwL/tQ/vU5LM3HWHOxERKZ/2KEREJJIShYgkzMwmlLlwrvgxIdWxSfKo60lERCIdkKM6Nm/e3Nu0aZPqMEREvjcWLFiwxd1bxJt3QCaKNm3akJubm+owRES+N8xsrwski+kYhYiIRFKiEBGRSClNFGbW14Iblqwxs5Fx5tczs2fD+e+bWZvqj1JEpGZLWaKw4EYrfyK48jQLGGRmWWWqXU0wsNnRwAPAfdUbpYiIpHKPojuwxoObpOwkGMOm7OBy5xPcKAaCsWfONEvO7TVFRCS+VCaKIyg9JHE+ew/JXFInvOfBNsoZRdTMhppZrpnlbt68OQnhiojUTKlMFPH2DMpe/ZdInaDQfaK757h7TosWcU8FFhGRSkjldRT5lB67PpNgDP54dfLDm+P8gGAkzEgrV0KvXlURosOePbB7d/Ao2h1M4+AJPMImSrWX9LIE3lNVqdKL+hVXChuTA0VGBhzZuuJ6+yiVieIDoIOZtSW4OcqllL4nAQQ3thlMcIOci4C3PJljjuzaCQUFsKUAdmyHwm8h4aH4RURSrE7dAytRuHuRmV1PMH59BvC4uy81szuBXHd/CXgMmGxmawj2JC5NpO1jj4XZs/cxoB07oFkzKCyE1q2hd/fg+dBDoXFjaNIkeK5fH+rUgdq1g0dGxnevY8syMsDsuweUnk5mWUUSPR9AbaWmLZ2vISkQ9WeX0iE83P01gjuDxZbdEfO6ELi4WoJp2BD+/Gc44QTIztY/q4hI6IAc66nSrrwy1RGIiKQdDeEhIiKRlChERCSSEoWIiERSohARkUhKFCIiEkmJQkREIilRiIhIJCUKERGJpEQhIiKRlChERCSSEoWIiERSohARkUhKFCIiEkmJQkREIilRiIhIJCUKERGJpEQhIiKRlChERCSSEoWIiERSohARkUhKFCIiEkmJQkREIilRiIhIJCUKERGJpEQhIiKRlChERCSSEoWIiERSohARkUhKFCIiEiklicLMDjGzmWa2OnxuWk693WaWFz5equ44RUQkdXsUI4E33b0D8GY4Hc837p4dPgZUX3giIlIsVYnifODJ8PWTwAUpikNERCqQqkRxmLtvAAifDy2nXn0zyzWz98wsMpmY2dCwbu7mzZurOl4RkRqrdrIaNrNZwOFxZo3ah2Zau/t6M2sHvGVmS9z943gV3X0iMBEgJyfH9zlgERGJK2mJwt37lDfPzD43s5buvsHMWgKbymljffj8iZnNBroAcROFiIgkR6q6nl4CBoevBwMvlq1gZk3NrF74ujlwMrCs2iIUEREgdYniXuAsM1sNnBVOY2Y5ZvZoWOeHQK6ZLQLeBu51dyUKEZFqlrSupyjuXgCcGac8F7gmfP0OcHw1hyYiImXoymwREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQipSRRmNnFZrbUzPaYWU5Evb5mttLM1pjZyOqMUUREAqnao/gIuBCYU14FM8sA/gT0A7KAQWaWVT3hiYhIsdqpWKm7Lwcws6hq3YE17v5JWHcqcD6wLOkBiohIiXQ+RnEEsC5mOj8si8vMhppZrpnlbt68OenBiYjUFEnbozCzWcDhcWaNcvcXE2kiTpmXV9ndJwITAXJycsqtJyIi+yZpicLd++xnE/nAkTHTmcD6/WxTRET2UTp3PX0AdDCztmZWF7gUeCnFMYmI1DgJJQozuziRskSZ2UAzywd6Aq+a2T/C8lZm9hqAuxcB1wP/AJYDz7n70squU0REKsfcK+7ON7MP3b1rRWXpIicnx3Nzc1MdhojI94aZLXD3uNe1RR6jMLN+wDnAEWY2PmbWQUBR1YUoIiLpqqKD2euBXGAAsCCm/CvgpmQFJSIi6SMyUbj7ImCRmT3j7ruqKSYREUkjiZ4e293MxgBHhcsY4O7eLlmBiYhIekg0UTxG0NW0ANidvHBERCTdJJootrn7jKRGIiIiaSnRRPG2mf0OeAH4trjQ3T9MSlQiIpI2Ek0UJ4bPsefYOnBG1YYjIiLpJqFE4e69kx2IiIikp0SH8DjMzB4zsxnhdJaZXZ3c0EREJB0kOijgJIIxl1qF06uAEckISERE0kuiiaK5uz8H7IGSAft0mqyISA2QaKLYbmbNCG8cZGY9gG1Ji0pERNJGomc9/YLgXhDtzWwe0AK4KGlRiYhI2kj0rKcPzex04FiC4TtWauwnEZGaoaJhxs9w97fM7MIys44xM9z9hSTGJiIiaaCiPYrTgbeA8+LMc4IrtUVE5ABW0TDjo8Pnq6onHBERSTcVdT39Imq+u99fteGIiEi6qajrqUn4fCzwI4IznyDoipqTrKBERCR9VNT1NBbAzN4Aurr7V+H0GGBa0qMTEZGUS/SCu9bAzpjpnUCbKo9GRETSTqIX3E0G5pvZdIKznQYCf01aVCIikjYSveDubjN7HTglLLrK3RcmLywREUkXie5R4O4LzGwdUB/AzFq7+2dJi0xERNJCovejGGBmq4F/A/8Mn3UPbRGRGiDRg9m/BnoAq9y9LdAHmJe0qEREJG0kmih2uXsBUMvMarn720B2EuMSEZE0kegxiq1m1pjgIrunzWwTUJS8sEREJF0kukdxPrADuAl4HfiY+AMFiojIAabCRGFmGcCL7r7H3Yvc/Ul3Hx92RVWKmV1sZkvNbI+Z5UTUW2tmS8wsz8xyK7s+ERGpvAq7ntx9t5ntMLMfuHtV3f70I+BC4C8J1O3t7luqaL0iIrKPEj1GUQgsMbOZwPbiQne/oTIrdfflAGZWmcVFRKQaJZooXg0fEAzhAcEtUZPNgTfMzIG/uPvE8iqa2VBgKEDr1q2rITQRkZqhovtRnA9kuvufwun5QAuCL/BbK1h2FnB4nFmj3P3FBOM72d3Xm9mhwEwzW+HucYc3D5PIRICcnByPV0dERPZdRXsUtwCXxkzXBboBjYEniBhq3N377G9w7r4+fN4UDkjYHd0HQ0SkWlV01lNdd18XMz3X3b8Ix3hqlMS4MLNGZtak+DVwNsFBcBERqUYVJYqmsRPufn3MZIvKrtTMBppZPtATeNXM/hGWtzKz18JqhwFzzWwRMB941d1fr+w6RUSkcirqenrfzP7L3R+JLTSz/yb48q4Ud58OTI9Tvh44J3z9CXBCZdchIiJVo6JEcRPwdzO7DPgwLOsG1AMuSGZgIiKSHiq6Z/Ym4CQzOwPoGBa/6u5vJT0yERFJC4ne4e4tQMlBRKQGSnRQQBERqaGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIRFKiEBGRSEoUIiISSYlCREQiKVGIiEgkJQoREYmkRCEiIpGUKEREJJIShYiIREpJojCz35nZCjNbbGbTzezgcur1NbOVZrbGzEZWd5wiIpK6PYqZQCd37wysAn5ZtoKZZQB/AvoBWcAgM8uq1ihFRCQ1icLd33D3onDyPSAzTrXuwBp3/8TddwJTgfOrK0YREQmkwzGKIcCMOOVHAOtipvPDMhERqUa1k9Wwmc0CDo8za5S7vxjWGQUUAU/HayJOmUesbygwFKB169b7HK9ITbBr1y7y8/MpLCxMdSiSIvXr1yczM5M6deokvEzSEoW794mab2aDgf7Ame4eLwHkA0fGTGcC6yPWNxGYCJCTk1NuQhGpyfLz82nSpAlt2rTBLN5vMTmQuTsFBQXk5+fTtm3bhJdL1VlPfYFbgQHuvqOcah8AHcysrZnVBS4FXqquGEUORIWFhTRr1kxJooYyM5o1a7bPe5SpOkbxR6AJMNPM8sxsAoCZtTKz1wDCg93XA/8AlgPPufvSFMUrcsBQkqjZKvP5J63rKYq7H11O+XrgnJjp14DXqisuERHZWzqc9SQiNYiZccUVV5RMFxUV0aJFC/r3779P7bRp04YtW7ZUqs7XX3/NsGHDaN++PV26dKFbt2488sgj+7T+siZNmsT1118PwIQJE/jrX/9aqXbWrl3LM888s1+xVDUlChGpVo0aNeKjjz7im2++AWDmzJkccUT1nvl+zTXX0LRpU1avXs3ChQt5/fXX+eKLL/aqt3v37kq1f+211/Kzn/2sUsumY6JISdeTiKSBESMgL69q28zOhgcfrLBav379ePXVV7nooouYMmUKgwYN4l//+hcAX3zxBUOGDOGTTz6hYcOGTJw4kc6dO1NQUMCgQYPYvHkz3bt3J/Zkyaeeeorx48ezc+dOTjzxRB5++GEyMjLirvvjjz9m/vz5PPPMM9SqFfxWbtGiBbfeeisAs2fPZuzYsbRs2ZK8vDyWLVvGBRdcwLp16ygsLOTGG29k6NChADzxxBPcc889tGzZkmOOOYZ69eoBMGbMGBo3bszNN9/Mxx9/zHXXXcfmzZtp2LAhjzzyCMcddxxXXnklBx10ELm5uWzcuJHf/va3XHTRRYwcOZLly5eTnZ3N4MGDuemmmyr/eVQR7VGISLW79NJLmTp1KoWFhSxevJgTTzyxZN7o0aPp0qULixcv5je/+U3JL/OxY8dyyimnsHDhQgYMGMBnn30GwPLly3n22WeZN28eeXl5ZGRk8PTT8S7NCixdupQTTjihJEnEM3/+fO6++26WLVsGwOOPP86CBQvIzc1l/PjxFBQUsGHDBkaPHs28efOYOXNmSd2yhg4dykMPPcSCBQsYN24cw4cPL5m3YcMG5s6dyyuvvMLIkcFwdvfeey+nnnoqeXl5aZEkQHsUIjVXAr/8k6Vz586sXbuWKVOmcM4555SaN3fuXP72t78BcMYZZ1BQUMC2bduYM2cOL7zwAgDnnnsuTZs2BeDNN99kwYIF/OhHPwLgm2++4dBDD004lrvvvptp06axadMm1q8PLtXq3r17qesMxo8fz/Tp0wFYt24dq1evZuPGjfTq1YsWLVoAcMkll7Bq1apSbX/99de88847XHzxxSVl3377bcnrCy64gFq1apGVlcXnn3+ecMzVTYlCRFJiwIAB3HzzzcyePZuCgoKS8njX3xaf0hnv1E53Z/Dgwdxzzz0JrTcrK4tFixaxZ88eatWqxahRoxg1ahSNGzcuqdOoUaOS17Nnz2bWrFm8++67NGzYkF69epVch1DRqaZ79uzh4IMPJq+cLr7irqri95Gu1PUkIikxZMgQ7rjjDo4//vhS5aeddlpJ19Hs2bNp3rw5Bx10UKnyGTNm8OWXXwJw5pln8vzzz7Np0yYgOMbx6aeflrveo48+mpycHG677baSg9WFhYXlflFv27aNpk2b0rBhQ1asWMF7770HwIknnliS5Hbt2sW0adP2Wvaggw6ibdu2JfPcnUWLFkVulyZNmvDVV19F1qluShQikhKZmZnceOONe5WPGTOG3NxcOnfuzMiRI3nyySeB4NjFnDlz6Nq1K2+88UbJmG5ZWVncddddnH322XTu3JmzzjqLDRs2RK770UcfpaCggKOPPppu3brRp08f7rvvvrh1+/btS1FREZ07d+b222+nR48eALRs2ZIxY8bQs2dP+vTpQ9euXeMu//TTT/PYY49xwgkn0LFjR1588cXI2Dp37kzt2rU54YQTeOCBByLrVhdL592dysrJyfHc3NxUhyGSdpYvX84Pf/jDVIchKRbv78DMFrh7Trz62qMQEZFIShQiIhJJiUJERCIpUYiISCQlChERiaREISIikZQoRKRaZWRkkJ2dTadOnbj44ovZsaO8m1xW/UiqvXr1IpFT51evXk3//v1p37493bp1o3fv3syZM2e/1n3llVfy/PPPA8HoteWNDVWR2bNn884775Q7f9SoURx55JGlrjTfX0oUIlKtGjRoQF5eHh999BF169ZlwoQJ5datTKIoKirar/gKCws599xzGTp0KB9//DELFizgoYce4pNPPqmydT366KNkZWVVatmKEsV5553H/PnzK9V2eTTWk0gNNeL1EeRtrNphxrMPz+bBvokPNnjqqaeyePFibr/9dpo3b15ypfaoUaM47LDDeOaZZ0oNuT1s2DCGDRtGbm4utWvX5v7776d3795MmjSJV199lcLCQrZv385bb73Fb3/7WyZPnkytWrXo168f9957LwDTpk1j+PDhbN26lccee4xTTz21VExPP/00PXv2ZMCAASVlnTp1olOnTkBw5fj69etZu3YtzZs35ze/+Q1XXHEF27dvB+CPf/wjJ510Eu7Oz3/+c9566y3atm1baoiQXr16MW7cOHJycnjjjTcYPXo03377Le3bt+eJJ56gcePGtGnThsGDB/Pyyy+XDBFSv359JkyYQEZGBk899RQPPfTQXvEXXzlelZQoRCQlioqKmDFjBn379qVfv35ceOGF3HjjjezZs4epU6cyf/58OnfuzLhx43jllVcA+P3vfw/AkiVLWLFiBWeffXbJiK3vvvsuixcv5pBDDmHGjBn8/e9/5/3336dhw4albkpUVFTE/Pnzee211xg7diyzZs0qFdfSpUvLHY6j2IIFC5g7dy4NGjRgx44dzJw5k/r167N69WoGDRpEbm4u06dPZ+XKlSxZsoTPP/+crKwshgwZUqqdLVu2cNdddzFr1iwaNWrEfffdx/33388dd9wBQPPmzfnwww95+OGHGTduHI8++ijXXnttyb0uqosShUgNtS+//KvSN998Q3Z2NhDsUVx99dXUrVuXZs2asXDhQj7//HO6dOlCs2bN9lp27ty5/PznPwfguOOO46ijjipJFGeddRaHHHIIALNmzeKqq66iYcOGACXlABdeeCEA3bp1Y+3atRXGO3DgQFavXs0xxxxTMsz5gAEDaNCgAQC7du3i+uuvL7kXRnE8c+bMYdCgQWRkZNCqVSvOOOOMvdp+7733WLZsGSeffDIAO3fupGfPnnFjLV53KihRiEi1Kj5GUdY111zDpEmT2Lhx416/vItFjU0XOzS4u5c7BHjx0N4ZGRlxjzF07Nix1IHr6dOnk5ubW+oXfOy6HnjgAQ477LCSocvr169fMq+iYcjdnbPOOospU6ZUKtbdu3fTrVs3IEhed955Z+T6KksHs0UkLQwcOJDXX3+dDz74gB//+MfA3kNuxw41vmrVKj777DOOPfbYvdo6++yzefzxx0vOqIp3P+zyXHbZZcybN4+XXnqppCzqzKxt27bRsmVLatWqxeTJk0uGLj/ttNOYOnUqu3fvZsOGDbz99tt7LdujRw/mzZvHmjVrStZT9uZHZcVuk4yMDPLy8sjLy0takgAlChFJE3Xr1qV379785Cc/Kbnfddkht4cPH87u3bs5/vjjueSSS5g0aVKpm/8U69u3LwMGDCAnJ4fs7GzGjRuXcBwNGjTglVdeYcKECbRr146ePXty1113cdttt8WtP3z4cJ588kl69OjBqlWrSvY2Bg4cSIcOHTj++OMZNmwYp59++l7LtmjRgkmTJjFo0CA6d+5Mjx49WLFiRWR85513HtOnTyc7O7vkPuOxbrnlFjIzM9mxYweZmZmMGTMm4fdeHg0zLlKDpPMw43v27KFr165MmzaNDh06pDqcA5qGGReR751ly5Zx9NFHc+aZZypJpCEdzBaRlMvKyop7QZukB+1RiIhIJCUKERGJpEQhIiKRlChERCSSEoWIVCsNM568YcZ37NjBueeey3HHHUfHjh0ZOXJkpeONlZJEYWa/M7MVZrbYzKab2cHl1FtrZkvMLM/MdGGEyAFAw4wnd5jxm2++mRUrVrBw4ULmzZvHjBkzKrWeWKk6PXYm8Et3LzKz+4BfAreWU7e3u2+pvtBEaoYRIyDOkEv7JTsbHtyHsQY1zHjVDjPesGFDevfuDQRXunft2pX8/Px9+ATjS0micPc3YibfAy5KRRwikjoaZjy5w4xv3bqVl19+uST57o90uOBuCPBsOfMceMPMHPiLu08srxEzGwoMBWjdunWVBylyoNmXX/5VScOMfydZw4wXFRUxaNAgbrjhBtq1a5fwcuVJWqIws1nA4XFmjXL3F8M6o4Ai4OlymjnZ3deb2aHATDNb4e5xjyiFSWQiBGM97fcbEJGk0DDjpd9PMoYZHzp0KB06dGDEiBGR609U0g5mu3sfd+8U51GcJAYD/YHLvZxP393Xh8+bgOlA92TFKyKppWHGq2aY8dtuu41t27bxYBXuMqak68nM+hIcvD7d3eN+AmbWCKjl7l+Fr88GEhpwfWXBSnpN6lVV4YocMEZ3HE2tLak9K97dWbllZdx5XXp24aCDDmLNl8EXZ71W9djpOzmu03EMvHQgg64cxL9u/hfHZB1D7Yza3PmHO1n71Vo2fLWBL7/5sqTdtjlt6XlmT2i9OFwAAAk/SURBVDp36UydOnU4rc9p/OK2X7Bj1w7Wbl1Lky1N+LLgS3bt2RU3locmP8S9t9/LdTdcR7MWzWjUuBHX/PwaVm5ZyZYdW9hhO0qW63tpX24YcgOTp0ym+8ndadiwISu3rCTr1Cyav9qcY7OOpU37NnTr2Y3//N9/WLll5XdxtGnCr//wawZePJCdO3cCMOKXI/BDnF17drGmYA0FFLB261p27ArW2fGUjtw45Eae+9tz3HbPbeT0/G7A143rN3L33XfTrkM7OnbuCMDlV1/OxVdcXOr9bfx6I8MmDUv4M0vJMONmtgaoBxSERe+5+7Vm1gp41N3PMbN2BHsRECS0Z9z97kTab9K2iXcb3a3K4xb5vhvdcTSt2rZKdRhx7dmzhwvPuJAHH3uQNu3bpDqcA9r6f69n7NKxpcr+edU/yx1mXPejEKlB0vV+FMuWLaN///4MHDiw5MwmSZ59vR9FOpz1JCI1nIYZT28awkOkhjkQexEkcZX5/JUoRGqQ+vXrU1BQoGRRQ7k7BQUFpU7hTYS6nkRqkMzMTPLz89m8eXOqQ5EUqV+/PpmZmfu0jBKFSA1Sp04d2rZtm+ow5HtGXU8iIhJJiUJERCIpUYiISKQD8oI7M9sMfLqPizUH0v2+F4qxaijGqvF9iBG+H3GmQ4xHuXuLeDMOyERRGWaWW95VielCMVYNxVg1vg8xwvcjznSPUV1PIiISSYlCREQiKVF8p9y756URxVg1FGPV+D7ECN+PONM6Rh2jEBGRSNqjEBGRSEoUIiISqcYnCjPra2YrzWyNmY1MdTwAZnakmb1tZsvNbKmZ3RiWH2JmM81sdfjcNA1izTCzhWb2Sjjd1szeD2N81szqpkGMB5vZ82a2ItymPdNtW5rZTeFn/ZGZTTGz+qnelmb2uJltMrOPYsribjcLjA//jxabWdcUxvi78LNebGbTzezgmHm/DGNcaWY/TlWMMfNuNjM3s+bhdEq2Y0VqdKIwswzgT0A/IAsYZGZZqY0KgCLgf9z9h0AP4LowrpHAm+7eAXgznE61G4HlMdP3AQ+EMX4JXJ2SqEr7A/C6ux8HnEAQb9psSzM7ArgByHH3TkAGcCmp35aTgL5lysrbbv2ADuFjKPDnFMY4E+jk7p2BVcAvAcL/oUuBjuEyD4ffAamIETM7EjgL+CymOFXbMVKNThRAd2CNu3/i7juBqcD5KY4Jd9/g7h+Gr78i+GI7giC2J8NqTwIXpCbCgJllAucCj4bTBpwBPB9WSYcYDwJOAx4DcPed7r6VNNuWBCM5NzCz2kBDYAMp3pbuPgf4okxxedvtfOCvHngPONjMWqYiRnd/w92Lwsn3gOIxtc8Hprr7t+7+b2ANwXdAtccYegC4BYg9oygl27EiNT1RHAGsi5nOD8vShpm1AboA7wOHufsGCJIJcGjqIgPgQYI/9D3hdDNga8w/aTpsz3bAZuCJsIvsUTNrRBptS3f/DzCO4JflBmAbsID025ZQ/nZL1/+lIcCM8HXaxGhmA4D/uPuiMrPSJsZYNT1RWJyytDlf2MwaA38DRrj7/6U6nlhm1h/Y5O4LYovjVE319qwNdAX+7O5dgO2kR5ddibCf/3ygLdAKaETQBVFWqrdllLT77M1sFEE37tPFRXGqVXuMZtYQGAXcEW92nLKUf+41PVHkA0fGTGcC61MUSylmVocgSTzt7i+ExZ8X74aGz5tSFR9wMjDAzNYSdNmdQbCHcXDYfQLpsT3zgXx3fz+cfp4gcaTTtuwD/NvdN7v7LuAF4CTSb1tC+dstrf6XzGww0B+43L+7WCxdYmxP8KNgUfj/kwl8aGaHkz4xllLTE8UHQIfw7JK6BAe6XkpxTMV9/Y8By939/phZLwGDw9eDgRerO7Zi7v5Ld8909zYE2+0td78ceBu4KKyW0hgB3H0jsM7Mjg2LzgSWkUbbkqDLqYeZNQw/++IY02pbhsrbbi8BPwvP2ukBbCvuoqpuZtYXuBUY4O47Yma9BFxqZvXMrC3BAeP51R2fuy9x90PdvU34/5MPdA3/VtNmO5bi7jX6AZxDcGbEx8CoVMcTxnQKwe7mYiAvfJxDcAzgTWB1+HxIqmMN4+0FvBK+bkfwz7cGmAbUS4P4soHccHv+HWiabtsSGAusAD4CJgP1Ur0tgSkEx0x2EXyZXV3ediPoMvlT+H+0hOAMrlTFuIagn7/4f2dCTP1RYYwrgX6pirHM/LVA81Rux4oeGsJDREQi1fSuJxERqYAShYiIRFKiEBGRSEoUIiISSYlCREQiKVFIjWZmbeKN6lmmziQzuyh8PSK8sraq1n9B7ECUZnanmfWpqvZFqoIShci+GUEwaF/CKhih9AKCkYsBcPc73H1WJWMTSQolCpGQmbULBw78UTnzbyAYi+ltM3s7LDvbzN41sw/NbFo4PhdmttbM7jCzucDFZvZfZvaBmS0ys7+FV2GfBAwAfmdmeWbWvszey5lhPEvCexrUi2l7bLjOJWZ2XFh+ethOXrhck6RvNKkRlChEgHCIj78BV7n7B/HquPt4gnF3ert77/BmM7cBfdy9K8HV37+IWaTQ3U9x96nAC+7+I3cvvh/G1e7+DsGQDf/r7tnu/nFMPPUJ7mNwibsfTzC44bCYtreE6/wzcHNYdjNwnbtnA6cC3+zPNhEppkQhAi0Ixiz6qbvn7cNyPQi6jeaZWR7B2EdHxcx/NuZ1JzP7l5ktAS4nuHlOlGMJBgpcFU4/SXBfjWLFA0UuANqEr+cB94d7Pgf7d0OUi+wXJQqR4P4P6whGxMXMngi7b16rYDkDZoZ7A9nunuXusXeh2x7zehJwfbh3MBaon0DbUb4Nn3cT7G3g7vcC1wANgPeKu6RE9lftiquIHPB2EhxU/oeZfe3uV0XU/QpoAmwhuHvan8zsaHdfE54NlRmzFxCrCbAhHD7+cuA/ZdorawXQprht4Argn1Fvwszau/sSYImZ9QSOC9sR2S/aoxAB3H07wf0LbjKzqNvhTgRmmNnb7r4ZuBKYYmaLCRJHeb/ibye4S+FMSn95TwX+Nzz43D4mnkLgKmBa2F21B5hQwdsYYWYfmdkiguMTMyqoL5IQjR4rIiKRtEchIiKRlChERCSSEoWIiERSohARkUhKFCIiEkmJQkREIilRiIhIpP8HIj/zIARfBjAAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots()\n", - "ax.plot(range(1,k), dx2_dz3_list, c = \"red\", linestyle = \"solid\", label = \"Model Gradient\")\n", - "\n", - "plt.axhline(y = pytorch_grad1, color = 'g', linestyle = '-', label = \"Pytorch Gradient-1\")\n", - "\n", - "plt.axhline(y = pytorch_grad2, color = 'blue', linestyle = '-', label = \"Pytorch Gradient-2\")\n", - "plt.legend(loc=\"lower right\")\n", - "plt.title(\"Gradient Plot for dx2_dz3\")\n", - "plt.ylabel(\"Gradient\")\n", - "plt.xlabel(\"k-Iterations\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3. Phys Model without Surrogate" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "--- Temporary TEST: Physics Example - without NOFAS\n", - "\n", - "--- Running on device: cpu\n", - "\n", - "\n", - "--- Running on device: cpu\n", - "\n", - "VI NF (t=1.000): it: 100 | loss: 9.842e+03\n", - "--- Saving results at iteration 200\n", - "VI NF (t=1.000): it: 200 | loss: 3.976e+03\n", - "VI NF (t=1.000): it: 300 | loss: 7.567e+02\n", - "--- Saving results at iteration 400\n", - "VI NF (t=1.000): it: 400 | loss: 1.483e+02\n", - "VI NF (t=1.000): it: 500 | loss: 1.947e+02\n", - "--- Saving results at iteration 600\n", - "VI NF (t=1.000): it: 600 | loss: 2.955e+02\n", - "VI NF (t=1.000): it: 700 | loss: 2.940e+02\n", - "--- Saving results at iteration 800\n", - "VI NF (t=1.000): it: 800 | loss: 1.907e+02\n", - "VI NF (t=1.000): it: 900 | loss: 1.403e+02\n", - "--- Saving results at iteration 1000\n", - "VI NF (t=1.000): it: 1000 | loss: 1.544e+02\n", - "VI NF (t=1.000): it: 1100 | loss: 1.082e+02\n", - "--- Saving results at iteration 1200\n", - "VI NF (t=1.000): it: 1200 | loss: 9.318e+01\n", - "VI NF (t=1.000): it: 1300 | loss: 1.706e+02\n", - "--- Saving results at iteration 1400\n", - "VI NF (t=1.000): it: 1400 | loss: 1.405e+02\n", - "VI NF (t=1.000): it: 1500 | loss: 1.443e+03\n", - "--- Saving results at iteration 1600\n", - "VI NF (t=1.000): it: 1600 | loss: 7.104e+01\n", - "VI NF (t=1.000): it: 1700 | loss: 7.039e+01\n", - "--- Saving results at iteration 1800\n", - "VI NF (t=1.000): it: 1800 | loss: 7.383e+01\n", - "VI NF (t=1.000): it: 1900 | loss: 7.144e+01\n", - "--- Saving results at iteration 2000\n", - "VI NF (t=1.000): it: 2000 | loss: 3.828e+01\n", - "VI NF (t=1.000): it: 2100 | loss: 5.490e+01\n", - "--- Saving results at iteration 2200\n", - "VI NF (t=1.000): it: 2200 | loss: 7.513e+01\n", - "VI NF (t=1.000): it: 2300 | loss: 4.001e+01\n", - "--- Saving results at iteration 2400\n", - "VI NF (t=1.000): it: 2400 | loss: 5.401e+01\n", - "VI NF (t=1.000): it: 2500 | loss: 6.711e+01\n", - "--- Saving results at iteration 2600\n", - "VI NF (t=1.000): it: 2600 | loss: 3.832e+01\n", - "VI NF (t=1.000): it: 2700 | loss: 3.001e+01\n", - "--- Saving results at iteration 2800\n", - "VI NF (t=1.000): it: 2800 | loss: 2.808e+01\n", - "VI NF (t=1.000): it: 2900 | loss: 2.579e+01\n", - "--- Saving results at iteration 3000\n", - "VI NF (t=1.000): it: 3000 | loss: 1.856e+01\n", - "VI NF (t=1.000): it: 3100 | loss: 1.719e+01\n", - "--- Saving results at iteration 3200\n", - "VI NF (t=1.000): it: 3200 | loss: 1.676e+01\n", - "VI NF (t=1.000): it: 3300 | loss: 1.639e+01\n", - "--- Saving results at iteration 3400\n", - "VI NF (t=1.000): it: 3400 | loss: 1.599e+01\n", - "VI NF (t=1.000): it: 3500 | loss: 1.647e+01\n", - "--- Saving results at iteration 3600\n", - "VI NF (t=1.000): it: 3600 | loss: 1.682e+01\n", - "VI NF (t=1.000): it: 3700 | loss: 1.697e+01\n", - "--- Saving results at iteration 3800\n", - "VI NF (t=1.000): it: 3800 | loss: 1.895e+01\n", - "VI NF (t=1.000): it: 3900 | loss: 1.579e+01\n", - "--- Saving results at iteration 4000\n", - "VI NF (t=1.000): it: 4000 | loss: 1.807e+01\n", - "VI NF (t=1.000): it: 4100 | loss: 1.734e+01\n", - "--- Saving results at iteration 4200\n", - "VI NF (t=1.000): it: 4200 | loss: 1.715e+01\n", - "VI NF (t=1.000): it: 4300 | loss: 1.836e+01\n", - "--- Saving results at iteration 4400\n", - "VI NF (t=1.000): it: 4400 | loss: 1.560e+01\n", - "VI NF (t=1.000): it: 4500 | loss: 1.536e+01\n", - "--- Saving results at iteration 4600\n", - "VI NF (t=1.000): it: 4600 | loss: 1.527e+01\n", - "VI NF (t=1.000): it: 4700 | loss: 1.557e+01\n", - "--- Saving results at iteration 4800\n", - "VI NF (t=1.000): it: 4800 | loss: 1.560e+01\n", - "VI NF (t=1.000): it: 4900 | loss: 1.621e+01\n", - "--- Saving results at iteration 5000\n", - "VI NF (t=1.000): it: 5000 | loss: 1.599e+01\n", - "VI NF (t=1.000): it: 5100 | loss: 1.559e+01\n", - "--- Saving results at iteration 5200\n", - "VI NF (t=1.000): it: 5200 | loss: 1.518e+01\n", - "VI NF (t=1.000): it: 5300 | loss: 1.533e+01\n", - "--- Saving results at iteration 5400\n", - "VI NF (t=1.000): it: 5400 | loss: 1.522e+01\n", - "VI NF (t=1.000): it: 5500 | loss: 1.515e+01\n", - "--- Saving results at iteration 5600\n", - "VI NF (t=1.000): it: 5600 | loss: 1.515e+01\n", - "VI NF (t=1.000): it: 5700 | loss: 1.502e+01\n", - "--- Saving results at iteration 5800\n", - "VI NF (t=1.000): it: 5800 | loss: 1.507e+01\n", - "VI NF (t=1.000): it: 5900 | loss: 1.505e+01\n", - "--- Saving results at iteration 6000\n", - "VI NF (t=1.000): it: 6000 | loss: 1.498e+01\n", - "VI NF (t=1.000): it: 6100 | loss: 1.486e+01\n", - "--- Saving results at iteration 6200\n", - "VI NF (t=1.000): it: 6200 | loss: 1.483e+01\n", - "VI NF (t=1.000): it: 6300 | loss: 1.482e+01\n", - "--- Saving results at iteration 6400\n", - "VI NF (t=1.000): it: 6400 | loss: 1.491e+01\n", - "VI NF (t=1.000): it: 6500 | loss: 1.488e+01\n", - "--- Saving results at iteration 6600\n", - "VI NF (t=1.000): it: 6600 | loss: 1.500e+01\n", - "VI NF (t=1.000): it: 6700 | loss: 1.481e+01\n", - "--- Saving results at iteration 6800\n", - "VI NF (t=1.000): it: 6800 | loss: 1.465e+01\n", - "VI NF (t=1.000): it: 6900 | loss: 1.478e+01\n", - "--- Saving results at iteration 7000\n", - "VI NF (t=1.000): it: 7000 | loss: 1.475e+01\n", - "VI NF (t=1.000): it: 7100 | loss: 1.473e+01\n", - "--- Saving results at iteration 7200\n", - "VI NF (t=1.000): it: 7200 | loss: 1.492e+01\n", - "VI NF (t=1.000): it: 7300 | loss: 1.492e+01\n", - "--- Saving results at iteration 7400\n", - "VI NF (t=1.000): it: 7400 | loss: 1.484e+01\n", - "VI NF (t=1.000): it: 7500 | loss: 1.465e+01\n", - "--- Saving results at iteration 7600\n", - "VI NF (t=1.000): it: 7600 | loss: 1.463e+01\n", - "VI NF (t=1.000): it: 7700 | loss: 1.478e+01\n", - "--- Saving results at iteration 7800\n", - "VI NF (t=1.000): it: 7800 | loss: 1.460e+01\n", - "VI NF (t=1.000): it: 7900 | loss: 1.459e+01\n", - "--- Saving results at iteration 8000\n", - "VI NF (t=1.000): it: 8000 | loss: 1.477e+01\n", - "VI NF (t=1.000): it: 8100 | loss: 1.654e+01\n", - "--- Saving results at iteration 8200\n", - "VI NF (t=1.000): it: 8200 | loss: 1.454e+01\n", - "VI NF (t=1.000): it: 8300 | loss: 1.462e+01\n", - "--- Saving results at iteration 8400\n", - "VI NF (t=1.000): it: 8400 | loss: 1.456e+01\n", - "VI NF (t=1.000): it: 8500 | loss: 1.448e+01\n", - "--- Saving results at iteration 8600\n", - "VI NF (t=1.000): it: 8600 | loss: 1.449e+01\n", - "VI NF (t=1.000): it: 8700 | loss: 1.469e+01\n", - "--- Saving results at iteration 8800\n", - "VI NF (t=1.000): it: 8800 | loss: 1.451e+01\n", - "VI NF (t=1.000): it: 8900 | loss: 1.454e+01\n", - "--- Saving results at iteration 9000\n", - "VI NF (t=1.000): it: 9000 | loss: 1.440e+01\n", - "VI NF (t=1.000): it: 9100 | loss: 1.450e+01\n", - "--- Saving results at iteration 9200\n", - "VI NF (t=1.000): it: 9200 | loss: 1.431e+01\n", - "VI NF (t=1.000): it: 9300 | loss: 1.434e+01\n", - "--- Saving results at iteration 9400\n", - "VI NF (t=1.000): it: 9400 | loss: 1.423e+01\n", - "VI NF (t=1.000): it: 9500 | loss: 1.437e+01\n", - "--- Saving results at iteration 9600\n", - "VI NF (t=1.000): it: 9600 | loss: 1.428e+01\n", - "VI NF (t=1.000): it: 9700 | loss: 1.439e+01\n", - "--- Saving results at iteration 9800\n", - "VI NF (t=1.000): it: 9800 | loss: 1.446e+01\n", - "VI NF (t=1.000): it: 9900 | loss: 1.422e+01\n", - "--- Saving results at iteration 10000\n", - "VI NF (t=1.000): it: 10000 | loss: 1.426e+01\n", - "VI NF (t=1.000): it: 10100 | loss: 1.417e+01\n", - "--- Saving results at iteration 10200\n", - "VI NF (t=1.000): it: 10200 | loss: 1.434e+01\n", - "VI NF (t=1.000): it: 10300 | loss: 1.434e+01\n", - "--- Saving results at iteration 10400\n", - "VI NF (t=1.000): it: 10400 | loss: 1.415e+01\n", - "VI NF (t=1.000): it: 10500 | loss: 1.431e+01\n", - "--- Saving results at iteration 10600\n", - "VI NF (t=1.000): it: 10600 | loss: 1.412e+01\n", - "VI NF (t=1.000): it: 10700 | loss: 1.410e+01\n", - "--- Saving results at iteration 10800\n", - "VI NF (t=1.000): it: 10800 | loss: 1.403e+01\n", - "VI NF (t=1.000): it: 10900 | loss: 1.416e+01\n", - "--- Saving results at iteration 11000\n", - "VI NF (t=1.000): it: 11000 | loss: 1.429e+01\n", - "VI NF (t=1.000): it: 11100 | loss: 1.409e+01\n", - "--- Saving results at iteration 11200\n", - "VI NF (t=1.000): it: 11200 | loss: 1.409e+01\n", - "VI NF (t=1.000): it: 11300 | loss: 1.395e+01\n", - "--- Saving results at iteration 11400\n", - "VI NF (t=1.000): it: 11400 | loss: 1.384e+01\n", - "VI NF (t=1.000): it: 11500 | loss: 1.384e+01\n", - "--- Saving results at iteration 11600\n", - "VI NF (t=1.000): it: 11600 | loss: 1.381e+01\n", - "VI NF (t=1.000): it: 11700 | loss: 1.371e+01\n", - "--- Saving results at iteration 11800\n", - "VI NF (t=1.000): it: 11800 | loss: 1.372e+01\n", - "VI NF (t=1.000): it: 11900 | loss: 1.354e+01\n", - "--- Saving results at iteration 12000\n", - "VI NF (t=1.000): it: 12000 | loss: 1.354e+01\n", - "VI NF (t=1.000): it: 12100 | loss: 1.350e+01\n", - "--- Saving results at iteration 12200\n", - "VI NF (t=1.000): it: 12200 | loss: 1.381e+01\n", - "VI NF (t=1.000): it: 12300 | loss: 1.349e+01\n", - "--- Saving results at iteration 12400\n", - "VI NF (t=1.000): it: 12400 | loss: 1.369e+01\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "VI NF (t=1.000): it: 12500 | loss: 1.344e+01\n", - "--- Saving results at iteration 12600\n", - "VI NF (t=1.000): it: 12600 | loss: 1.367e+01\n", - "VI NF (t=1.000): it: 12700 | loss: 1.331e+01\n", - "--- Saving results at iteration 12800\n", - "VI NF (t=1.000): it: 12800 | loss: 1.330e+01\n", - "VI NF (t=1.000): it: 12900 | loss: 1.331e+01\n", - "--- Saving results at iteration 13000\n", - "VI NF (t=1.000): it: 13000 | loss: 1.315e+01\n", - "VI NF (t=1.000): it: 13100 | loss: 1.328e+01\n", - "--- Saving results at iteration 13200\n", - "VI NF (t=1.000): it: 13200 | loss: 1.322e+01\n", - "VI NF (t=1.000): it: 13300 | loss: 1.316e+01\n", - "--- Saving results at iteration 13400\n", - "VI NF (t=1.000): it: 13400 | loss: 1.316e+01\n", - "VI NF (t=1.000): it: 13500 | loss: 1.306e+01\n", - "--- Saving results at iteration 13600\n", - "VI NF (t=1.000): it: 13600 | loss: 1.339e+01\n", - "VI NF (t=1.000): it: 13700 | loss: 1.365e+01\n", - "--- Saving results at iteration 13800\n", - "VI NF (t=1.000): it: 13800 | loss: 1.296e+01\n", - "VI NF (t=1.000): it: 13900 | loss: 1.326e+01\n", - "--- Saving results at iteration 14000\n", - "VI NF (t=1.000): it: 14000 | loss: 1.304e+01\n", - "VI NF (t=1.000): it: 14100 | loss: 1.296e+01\n", - "--- Saving results at iteration 14200\n", - "VI NF (t=1.000): it: 14200 | loss: 1.300e+01\n", - "VI NF (t=1.000): it: 14300 | loss: 1.278e+01\n", - "--- Saving results at iteration 14400\n", - "VI NF (t=1.000): it: 14400 | loss: 1.319e+01\n", - "VI NF (t=1.000): it: 14500 | loss: 1.287e+01\n", - "--- Saving results at iteration 14600\n", - "VI NF (t=1.000): it: 14600 | loss: 1.300e+01\n", - "VI NF (t=1.000): it: 14700 | loss: 1.284e+01\n", - "--- Saving results at iteration 14800\n", - "VI NF (t=1.000): it: 14800 | loss: 1.295e+01\n", - "VI NF (t=1.000): it: 14900 | loss: 1.281e+01\n", - "--- Saving results at iteration 15000\n", - "VI NF (t=1.000): it: 15000 | loss: 1.306e+01\n", - "VI NF (t=1.000): it: 15100 | loss: 1.321e+01\n", - "--- Saving results at iteration 15200\n", - "VI NF (t=1.000): it: 15200 | loss: 1.287e+01\n", - "VI NF (t=1.000): it: 15300 | loss: 1.277e+01\n", - "--- Saving results at iteration 15400\n", - "VI NF (t=1.000): it: 15400 | loss: 1.274e+01\n", - "VI NF (t=1.000): it: 15500 | loss: 1.297e+01\n", - "--- Saving results at iteration 15600\n", - "VI NF (t=1.000): it: 15600 | loss: 1.300e+01\n", - "VI NF (t=1.000): it: 15700 | loss: 1.307e+01\n", - "--- Saving results at iteration 15800\n", - "VI NF (t=1.000): it: 15800 | loss: 1.288e+01\n", - "VI NF (t=1.000): it: 15900 | loss: 1.286e+01\n", - "--- Saving results at iteration 16000\n", - "VI NF (t=1.000): it: 16000 | loss: 1.291e+01\n", - "VI NF (t=1.000): it: 16100 | loss: 1.272e+01\n", - "--- Saving results at iteration 16200\n", - "VI NF (t=1.000): it: 16200 | loss: 1.286e+01\n", - "VI NF (t=1.000): it: 16300 | loss: 1.275e+01\n", - "--- Saving results at iteration 16400\n", - "VI NF (t=1.000): it: 16400 | loss: 1.317e+01\n", - "VI NF (t=1.000): it: 16500 | loss: 1.292e+01\n", - "--- Saving results at iteration 16600\n", - "VI NF (t=1.000): it: 16600 | loss: 1.276e+01\n", - "VI NF (t=1.000): it: 16700 | loss: 1.324e+01\n", - "--- Saving results at iteration 16800\n", - "VI NF (t=1.000): it: 16800 | loss: 1.278e+01\n", - "VI NF (t=1.000): it: 16900 | loss: 1.273e+01\n", - "--- Saving results at iteration 17000\n", - "VI NF (t=1.000): it: 17000 | loss: 1.280e+01\n", - "VI NF (t=1.000): it: 17100 | loss: 1.282e+01\n", - "--- Saving results at iteration 17200\n", - "VI NF (t=1.000): it: 17200 | loss: 1.297e+01\n", - "VI NF (t=1.000): it: 17300 | loss: 1.281e+01\n", - "--- Saving results at iteration 17400\n", - "VI NF (t=1.000): it: 17400 | loss: 1.306e+01\n", - "VI NF (t=1.000): it: 17500 | loss: 1.292e+01\n", - "--- Saving results at iteration 17600\n", - "VI NF (t=1.000): it: 17600 | loss: 1.274e+01\n", - "VI NF (t=1.000): it: 17700 | loss: 1.407e+01\n", - "--- Saving results at iteration 17800\n", - "VI NF (t=1.000): it: 17800 | loss: 1.272e+01\n", - "VI NF (t=1.000): it: 17900 | loss: 1.296e+01\n", - "--- Saving results at iteration 18000\n", - "VI NF (t=1.000): it: 18000 | loss: 1.271e+01\n", - "VI NF (t=1.000): it: 18100 | loss: 1.276e+01\n", - "--- Saving results at iteration 18200\n", - "VI NF (t=1.000): it: 18200 | loss: 1.300e+01\n", - "VI NF (t=1.000): it: 18300 | loss: 1.262e+01\n", - "--- Saving results at iteration 18400\n", - "VI NF (t=1.000): it: 18400 | loss: 1.317e+01\n", - "VI NF (t=1.000): it: 18500 | loss: 1.283e+01\n", - "--- Saving results at iteration 18600\n", - "VI NF (t=1.000): it: 18600 | loss: 1.270e+01\n", - "VI NF (t=1.000): it: 18700 | loss: 1.307e+01\n", - "--- Saving results at iteration 18800\n", - "VI NF (t=1.000): it: 18800 | loss: 1.280e+01\n", - "VI NF (t=1.000): it: 18900 | loss: 1.267e+01\n", - "--- Saving results at iteration 19000\n", - "VI NF (t=1.000): it: 19000 | loss: 1.301e+01\n", - "VI NF (t=1.000): it: 19100 | loss: 1.273e+01\n", - "--- Saving results at iteration 19200\n", - "VI NF (t=1.000): it: 19200 | loss: 1.313e+01\n", - "VI NF (t=1.000): it: 19300 | loss: 1.281e+01\n", - "--- Saving results at iteration 19400\n", - "VI NF (t=1.000): it: 19400 | loss: 1.270e+01\n", - "VI NF (t=1.000): it: 19500 | loss: 1.286e+01\n", - "--- Saving results at iteration 19600\n", - "VI NF (t=1.000): it: 19600 | loss: 1.270e+01\n", - "VI NF (t=1.000): it: 19700 | loss: 1.291e+01\n", - "--- Saving results at iteration 19800\n", - "VI NF (t=1.000): it: 19800 | loss: 1.302e+01\n", - "VI NF (t=1.000): it: 19900 | loss: 1.279e+01\n", - "--- Saving results at iteration 20000\n", - "VI NF (t=1.000): it: 20000 | loss: 1.316e+01\n", - "VI NF (t=1.000): it: 20100 | loss: 1.274e+01\n", - "--- Saving results at iteration 20200\n", - "VI NF (t=1.000): it: 20200 | loss: 1.270e+01\n", - "VI NF (t=1.000): it: 20300 | loss: 1.270e+01\n", - "--- Saving results at iteration 20400\n", - "VI NF (t=1.000): it: 20400 | loss: 1.256e+01\n", - "VI NF (t=1.000): it: 20500 | loss: 1.276e+01\n", - "--- Saving results at iteration 20600\n", - "VI NF (t=1.000): it: 20600 | loss: 1.273e+01\n", - "VI NF (t=1.000): it: 20700 | loss: 1.308e+01\n", - "--- Saving results at iteration 20800\n", - "VI NF (t=1.000): it: 20800 | loss: 1.295e+01\n", - "VI NF (t=1.000): it: 20900 | loss: 1.314e+01\n", - "--- Saving results at iteration 21000\n", - "VI NF (t=1.000): it: 21000 | loss: 1.274e+01\n", - "VI NF (t=1.000): it: 21100 | loss: 1.266e+01\n", - "--- Saving results at iteration 21200\n", - "VI NF (t=1.000): it: 21200 | loss: 1.281e+01\n", - "VI NF (t=1.000): it: 21300 | loss: 1.271e+01\n", - "--- Saving results at iteration 21400\n", - "VI NF (t=1.000): it: 21400 | loss: 1.356e+01\n", - "VI NF (t=1.000): it: 21500 | loss: 1.270e+01\n", - "--- Saving results at iteration 21600\n", - "VI NF (t=1.000): it: 21600 | loss: 1.292e+01\n", - "VI NF (t=1.000): it: 21700 | loss: 1.270e+01\n", - "--- Saving results at iteration 21800\n", - "VI NF (t=1.000): it: 21800 | loss: 1.263e+01\n", - "VI NF (t=1.000): it: 21900 | loss: 1.280e+01\n", - "--- Saving results at iteration 22000\n", - "VI NF (t=1.000): it: 22000 | loss: 1.273e+01\n", - "VI NF (t=1.000): it: 22100 | loss: 1.279e+01\n", - "--- Saving results at iteration 22200\n", - "VI NF (t=1.000): it: 22200 | loss: 1.269e+01\n", - "VI NF (t=1.000): it: 22300 | loss: 1.265e+01\n", - "--- Saving results at iteration 22400\n", - "VI NF (t=1.000): it: 22400 | loss: 1.279e+01\n", - "VI NF (t=1.000): it: 22500 | loss: 1.260e+01\n", - "--- Saving results at iteration 22600\n", - "VI NF (t=1.000): it: 22600 | loss: 1.285e+01\n", - "VI NF (t=1.000): it: 22700 | loss: 1.261e+01\n", - "--- Saving results at iteration 22800\n", - "VI NF (t=1.000): it: 22800 | loss: 1.271e+01\n", - "VI NF (t=1.000): it: 22900 | loss: 1.334e+01\n", - "--- Saving results at iteration 23000\n", - "VI NF (t=1.000): it: 23000 | loss: 1.261e+01\n", - "VI NF (t=1.000): it: 23100 | loss: 1.301e+01\n", - "--- Saving results at iteration 23200\n", - "VI NF (t=1.000): it: 23200 | loss: 1.341e+01\n", - "VI NF (t=1.000): it: 23300 | loss: 1.275e+01\n", - "--- Saving results at iteration 23400\n", - "VI NF (t=1.000): it: 23400 | loss: 1.259e+01\n", - "VI NF (t=1.000): it: 23500 | loss: 1.290e+01\n", - "--- Saving results at iteration 23600\n", - "VI NF (t=1.000): it: 23600 | loss: 1.248e+01\n", - "VI NF (t=1.000): it: 23700 | loss: 1.270e+01\n", - "--- Saving results at iteration 23800\n", - "VI NF (t=1.000): it: 23800 | loss: 1.370e+01\n", - "VI NF (t=1.000): it: 23900 | loss: 1.259e+01\n", - "--- Saving results at iteration 24000\n", - "VI NF (t=1.000): it: 24000 | loss: 1.272e+01\n", - "VI NF (t=1.000): it: 24100 | loss: 1.295e+01\n", - "--- Saving results at iteration 24200\n", - "VI NF (t=1.000): it: 24200 | loss: 1.261e+01\n", - "VI NF (t=1.000): it: 24300 | loss: 1.259e+01\n", - "--- Saving results at iteration 24400\n", - "VI NF (t=1.000): it: 24400 | loss: 1.287e+01\n", - "VI NF (t=1.000): it: 24500 | loss: 1.251e+01\n", - "--- Saving results at iteration 24600\n", - "VI NF (t=1.000): it: 24600 | loss: 1.266e+01\n", - "VI NF (t=1.000): it: 24700 | loss: 1.287e+01\n", - "--- Saving results at iteration 24800\n", - "VI NF (t=1.000): it: 24800 | loss: 1.254e+01\n", - "VI NF (t=1.000): it: 24900 | loss: 1.288e+01\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Saving results at iteration 25000\n", - "VI NF (t=1.000): it: 25000 | loss: 1.271e+01\n", - "\n", - "--- Simulation completed!\n" - ] - } - ], - "source": [ - "print('')\n", - "print('--- Temporary TEST: Physics Example - without NOFAS')\n", - "print('')\n", - "\n", - "# Experiment Setting\n", - "exp = experiment()\n", - "exp.name = \"phys_nofasFree\" # str: Name of experiment\n", - "exp.flow_type = 'maf' # str: Type of flow\n", - "exp.n_blocks = 5 # int: Number of layers \n", - "exp.hidden_size = 100 # int: Hidden layer size for MADE in each layer \n", - "exp.n_hidden = 1 # int: Number of hidden layers in each MADE \n", - "exp.activation_fn = 'relu' # str: Actication function used \n", - "exp.input_order = 'sequential' # str: Input order for create_mask \n", - "exp.batch_norm_order = True # boolean: Order to decide if batch_norm is used \n", - "exp.sampling_interval = 5000 # int: How often to sample from normalizing flow\n", - "\n", - "exp.input_size = 3 # int: Dimensionality of input \n", - "exp.batch_size = 250 # int: Number of samples generated \n", - "exp.true_data_num = 2 # double: number of true model evaluated \n", - "exp.n_iter = 25001 # int: Number of iterations \n", - "exp.lr = 0.01 # float: Learning rate \n", - "exp.lr_decay = 0.9999 # float: Learning rate decay \n", - "exp.log_interval = 100 # int: How often to show loss stat \n", - "\n", - "exp.run_nofas = False # boolean: to run experiment with nofas\n", - "exp.annealing = False # boolean: to run experiment with annealing\n", - "exp.calibrate_interval = 1000 # int: How often to update surrogate model \n", - "exp.budget = 260 # int: Total number of true model evaluation\n", - "\n", - "exp.surr_pre_it = 20000 # int: Number of pre-training iterations for surrogate model\n", - "exp.surr_upd_it = 6000 # int: Number of iterations for the surrogate model update\n", - "exp.surr_folder = \"./\"\n", - "exp.use_new_surr = True # boolean: to run experiment with nofas\n", - "\n", - "exp.output_dir = './' + exp.name # str: output directory location\n", - "exp.results_file = 'results.txt' # str: result text file name\n", - "exp.log_file = 'log.txt' # str: log text file name\n", - "exp.samples_file = 'samples.txt' # str: sample text file name\n", - "exp.seed = random.randint(0, 10 ** 9) # int: Random seed used\n", - "exp.n_sample = 5000 # int: Total number of iterations\n", - "exp.no_cuda = True # boolean: to run experiment with NO cuda\n", - "\n", - "exp.optimizer = 'RMSprop' # str: Type of optimizer\n", - "exp.lr_scheduler = 'ExponentialLR' # str: Type of scheduler\n", - "\n", - "exp.device = torch.device('cuda:0' if torch.cuda.is_available() and not exp.no_cuda else 'cpu')\n", - "\n", - "print('--- Running on device: '+ str(exp.device))\n", - "print('')\n", - "\n", - "# Define transformation based on normalization rate\n", - "trsf_info = [['identity',0.0,0.0,0.0,0.0],\n", - " ['identity',0.0,0.0,0.0,0.0],\n", - " ['linear',-3,3,30.0,80.0]]\n", - "trsf = Transformation(trsf_info) \n", - "exp.transform = trsf\n", - "\n", - "# Define model\n", - "exp.model = model\n", - "\n", - "# Get data\n", - "model.data = np.loadtxt('./data_phys.txt')\n", - "\n", - "# Run experiment without surrogate\n", - "exp.surrogate = None\n", - "\n", - "## Define log density\n", - "# x: original, untransformed inputs\n", - "# model: our model\n", - "# transform: our transformation \n", - "def log_density(x, model, transform):\n", - " # Compute transformation log Jacobian\n", - " adjust = transform.compute_log_jacob_func(x)\n", - "\n", - " batch_size = x.size(0)\n", - " # Get the absolute values of the standard deviations\n", - " stds = torch.abs(model.solve_t(model.defParam)) * model.stdRatio\n", - " Data = torch.tensor(model.data)\n", - "\n", - " # Get model output without surrogate\n", - " modelOut = model.solve_t(transform.forward(x))\n", - "\n", - " # Eval LL\n", - " ll1 = -0.5 * np.prod(model.data.shape) * np.log(2.0 * np.pi)\n", - " ll2 = (-0.5 * model.data.shape[1] * torch.log(torch.prod(stds))).item()\n", - " ll3 = 0.0\n", - " for i in range(3):\n", - " ll3 += - 0.5 * torch.sum(((modelOut[:, i].unsqueeze(1) - Data[i, :].unsqueeze(0)) / stds[0, i]) ** 2, dim=1)\n", - " negLL = -(ll1 + ll2 + ll3)\n", - " res = -negLL.reshape(x.size(0), 1) + adjust\n", - " return res\n", - "\n", - "# Assign logdensity\n", - "exp.model_logdensity = lambda x: log_density(x, model, trsf)\n", - "\n", - "# Run VI\n", - "exp.run()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4. Phys Model with Surrogate" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "--- Temporary TEST: Physics Example - NOFAS\n", - "\n", - "--- Running on device: cpu\n", - "\n", - "Success: Pre-Grid found.\n", - "Warning: Surrogate model files: ./phys.npz and ./phys.npz could not be found. \n", - "\n", - "--- Pre-training surrogate model\n", - "\n", - "SUR: PRE: it: 0 | loss: 2.878e+00\n", - "SUR: PRE: it: 500 | loss: 2.143e-02\n", - "SUR: PRE: it: 1000 | loss: 3.403e-02\n", - "SUR: PRE: it: 1500 | loss: 6.723e-02\n", - "SUR: PRE: it: 2000 | loss: 4.098e-02\n", - "SUR: PRE: it: 2500 | loss: 2.187e-02\n", - "SUR: PRE: it: 3000 | loss: 4.471e-02\n", - "SUR: PRE: it: 3500 | loss: 1.801e-02\n", - "SUR: PRE: it: 4000 | loss: 1.560e-02\n", - "SUR: PRE: it: 4500 | loss: 1.199e-02\n", - "SUR: PRE: it: 5000 | loss: 1.182e-02\n", - "SUR: PRE: it: 5500 | loss: 4.405e-03\n", - "SUR: PRE: it: 6000 | loss: 1.455e-02\n", - "SUR: PRE: it: 6500 | loss: 7.392e-03\n", - "SUR: PRE: it: 7000 | loss: 3.385e-03\n", - "SUR: PRE: it: 7500 | loss: 1.512e-02\n", - "SUR: PRE: it: 8000 | loss: 7.596e-03\n", - "SUR: PRE: it: 8500 | loss: 3.101e-03\n", - "SUR: PRE: it: 9000 | loss: 9.216e-03\n", - "SUR: PRE: it: 9500 | loss: 7.438e-03\n", - "SUR: PRE: it: 10000 | loss: 3.331e-03\n", - "SUR: PRE: it: 10500 | loss: 9.847e-03\n", - "SUR: PRE: it: 11000 | loss: 4.369e-03\n", - "SUR: PRE: it: 11500 | loss: 5.205e-03\n", - "SUR: PRE: it: 12000 | loss: 3.043e-03\n", - "SUR: PRE: it: 12500 | loss: 7.762e-03\n", - "SUR: PRE: it: 13000 | loss: 4.609e-03\n", - "SUR: PRE: it: 13500 | loss: 1.927e-03\n", - "SUR: PRE: it: 14000 | loss: 2.427e-03\n", - "SUR: PRE: it: 14500 | loss: 2.380e-03\n", - "SUR: PRE: it: 15000 | loss: 2.590e-04\n", - "SUR: PRE: it: 15500 | loss: 6.610e-04\n", - "SUR: PRE: it: 16000 | loss: 2.573e-03\n", - "SUR: PRE: it: 16500 | loss: 1.311e-03\n", - "SUR: PRE: it: 17000 | loss: 1.291e-03\n", - "SUR: PRE: it: 17500 | loss: 3.512e-04\n", - "SUR: PRE: it: 18000 | loss: 5.945e-04\n", - "SUR: PRE: it: 18500 | loss: 1.682e-03\n", - "SUR: PRE: it: 19000 | loss: 3.748e-04\n", - "SUR: PRE: it: 19500 | loss: 6.141e-04\n", - "\n", - "--- Surrogate model pre-train complete\n", - "\n", - "Success: [limits] loaded.\n", - "Success: [pre_grid] loaded.\n", - "Success: [grid_record] loaded.\n", - "\n", - "--- Running on device: cpu\n", - "\n", - "VI NF (t=1.000): it: 100 | loss: 8.587e+03\n", - "--- Saving results at iteration 200\n", - "VI NF (t=1.000): it: 200 | loss: 4.078e+03\n", - "VI NF (t=1.000): it: 300 | loss: 1.863e+03\n", - "--- Saving results at iteration 400\n", - "VI NF (t=1.000): it: 400 | loss: 7.934e+02\n", - "VI NF (t=1.000): it: 500 | loss: 2.534e+02\n", - "--- Saving results at iteration 600\n", - "VI NF (t=1.000): it: 600 | loss: 1.288e+02\n", - "VI NF (t=1.000): it: 700 | loss: 1.219e+02\n", - "--- Saving results at iteration 800\n", - "VI NF (t=1.000): it: 800 | loss: 1.233e+02\n", - "VI NF (t=1.000): it: 900 | loss: 1.207e+02\n", - "--- Saving results at iteration 1000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "1.312e-02 -> 1.470e-01\n", - "7.577e-04 -> 2.723e-01\n", - "1.694e-02 -> 9.900e-02\n", - "\n", - "SUR: UPD: it: 0 | loss: 5.734e-02\n", - "SUR: UPD: it: 500 | loss: 3.967e-03\n", - "SUR: UPD: it: 1000 | loss: 2.285e-03\n", - "SUR: UPD: it: 1500 | loss: 1.289e-03\n", - "SUR: UPD: it: 2000 | loss: 1.214e-03\n", - "SUR: UPD: it: 2500 | loss: 1.198e-03\n", - "SUR: UPD: it: 3000 | loss: 1.175e-03\n", - "SUR: UPD: it: 3500 | loss: 1.150e-03\n", - "SUR: UPD: it: 4000 | loss: 1.146e-03\n", - "SUR: UPD: it: 4500 | loss: 1.143e-03\n", - "SUR: UPD: it: 5000 | loss: 1.141e-03\n", - "SUR: UPD: it: 5500 | loss: 1.140e-03\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 1000 | loss: 1.340e+03\n", - "VI NF (t=1.000): it: 1100 | loss: 2.904e+02\n", - "--- Saving results at iteration 1200\n", - "VI NF (t=1.000): it: 1200 | loss: 7.772e+01\n", - "VI NF (t=1.000): it: 1300 | loss: 8.895e+01\n", - "--- Saving results at iteration 1400\n", - "VI NF (t=1.000): it: 1400 | loss: 8.059e+01\n", - "VI NF (t=1.000): it: 1500 | loss: 6.278e+01\n", - "--- Saving results at iteration 1600\n", - "VI NF (t=1.000): it: 1600 | loss: 6.714e+01\n", - "VI NF (t=1.000): it: 1700 | loss: 1.369e+02\n", - "--- Saving results at iteration 1800\n", - "VI NF (t=1.000): it: 1800 | loss: 7.575e+01\n", - "VI NF (t=1.000): it: 1900 | loss: 4.249e+01\n", - "--- Saving results at iteration 2000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.534e-04 -> 1.154e-01\n", - "4.653e-06 -> 4.538e-02\n", - "2.291e-04 -> 2.527e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.482e-02\n", - "SUR: UPD: it: 500 | loss: 6.411e-03\n", - "SUR: UPD: it: 1000 | loss: 3.761e-03\n", - "SUR: UPD: it: 1500 | loss: 2.164e-03\n", - "SUR: UPD: it: 2000 | loss: 2.132e-03\n", - "SUR: UPD: it: 2500 | loss: 2.075e-03\n", - "SUR: UPD: it: 3000 | loss: 2.084e-03\n", - "SUR: UPD: it: 3500 | loss: 2.040e-03\n", - "SUR: UPD: it: 4000 | loss: 1.992e-03\n", - "SUR: UPD: it: 4500 | loss: 1.971e-03\n", - "SUR: UPD: it: 5000 | loss: 1.965e-03\n", - "SUR: UPD: it: 5500 | loss: 1.962e-03\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 2000 | loss: 9.921e+02\n", - "VI NF (t=1.000): it: 2100 | loss: 2.138e+02\n", - "--- Saving results at iteration 2200\n", - "VI NF (t=1.000): it: 2200 | loss: 2.356e+02\n", - "VI NF (t=1.000): it: 2300 | loss: 1.712e+02\n", - "--- Saving results at iteration 2400\n", - "VI NF (t=1.000): it: 2400 | loss: 1.678e+02\n", - "VI NF (t=1.000): it: 2500 | loss: 2.024e+02\n", - "--- Saving results at iteration 2600\n", - "VI NF (t=1.000): it: 2600 | loss: 1.824e+02\n", - "VI NF (t=1.000): it: 2700 | loss: 1.664e+02\n", - "--- Saving results at iteration 2800\n", - "VI NF (t=1.000): it: 2800 | loss: 1.543e+02\n", - "VI NF (t=1.000): it: 2900 | loss: 1.526e+02\n", - "--- Saving results at iteration 3000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.150e-05 -> 6.847e-02\n", - "5.379e-04 -> 4.192e-02\n", - "2.559e-04 -> 3.516e-02\n", - "\n", - "SUR: UPD: it: 0 | loss: 8.240e-03\n", - "SUR: UPD: it: 500 | loss: 1.114e-02\n", - "SUR: UPD: it: 1000 | loss: 6.015e-03\n", - "SUR: UPD: it: 1500 | loss: 6.128e-03\n", - "SUR: UPD: it: 2000 | loss: 4.154e-03\n", - "SUR: UPD: it: 2500 | loss: 3.670e-03\n", - "SUR: UPD: it: 3000 | loss: 3.605e-03\n", - "SUR: UPD: it: 3500 | loss: 3.515e-03\n", - "SUR: UPD: it: 4000 | loss: 3.454e-03\n", - "SUR: UPD: it: 4500 | loss: 3.414e-03\n", - "SUR: UPD: it: 5000 | loss: 3.387e-03\n", - "SUR: UPD: it: 5500 | loss: 3.363e-03\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 3000 | loss: 1.065e+02\n", - "VI NF (t=1.000): it: 3100 | loss: 8.990e+01\n", - "--- Saving results at iteration 3200\n", - "VI NF (t=1.000): it: 3200 | loss: 7.114e+01\n", - "VI NF (t=1.000): it: 3300 | loss: 8.242e+01\n", - "--- Saving results at iteration 3400\n", - "VI NF (t=1.000): it: 3400 | loss: 7.382e+01\n", - "VI NF (t=1.000): it: 3500 | loss: 5.745e+01\n", - "--- Saving results at iteration 3600\n", - "VI NF (t=1.000): it: 3600 | loss: 5.898e+01\n", - "VI NF (t=1.000): it: 3700 | loss: 5.598e+01\n", - "--- Saving results at iteration 3800\n", - "VI NF (t=1.000): it: 3800 | loss: 5.288e+01\n", - "VI NF (t=1.000): it: 3900 | loss: 5.871e+01\n", - "--- Saving results at iteration 4000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "1.008e-01 -> 1.008e-01\n", - "6.865e-02 -> 9.545e-02\n", - "1.079e-01 -> 1.079e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.303e-02\n", - "SUR: UPD: it: 500 | loss: 1.756e-02\n", - "SUR: UPD: it: 1000 | loss: 9.240e-03\n", - "SUR: UPD: it: 1500 | loss: 5.613e-03\n", - "SUR: UPD: it: 2000 | loss: 5.445e-03\n", - "SUR: UPD: it: 2500 | loss: 5.325e-03\n", - "SUR: UPD: it: 3000 | loss: 5.330e-03\n", - "SUR: UPD: it: 3500 | loss: 5.140e-03\n", - "SUR: UPD: it: 4000 | loss: 5.057e-03\n", - "SUR: UPD: it: 4500 | loss: 4.965e-03\n", - "SUR: UPD: it: 5000 | loss: 4.874e-03\n", - "SUR: UPD: it: 5500 | loss: 4.807e-03\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 4000 | loss: 1.266e+02\n", - "VI NF (t=1.000): it: 4100 | loss: 1.235e+02\n", - "--- Saving results at iteration 4200\n", - "VI NF (t=1.000): it: 4200 | loss: 1.223e+02\n", - "VI NF (t=1.000): it: 4300 | loss: 1.105e+02\n", - "--- Saving results at iteration 4400\n", - "VI NF (t=1.000): it: 4400 | loss: 1.106e+02\n", - "VI NF (t=1.000): it: 4500 | loss: 1.129e+02\n", - "--- Saving results at iteration 4600\n", - "VI NF (t=1.000): it: 4600 | loss: 1.165e+02\n", - "VI NF (t=1.000): it: 4700 | loss: 1.272e+02\n", - "--- Saving results at iteration 4800\n", - "VI NF (t=1.000): it: 4800 | loss: 1.189e+02\n", - "VI NF (t=1.000): it: 4900 | loss: 1.115e+02\n", - "--- Saving results at iteration 5000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.044e-01 -> 3.044e-01\n", - "2.927e-01 -> 2.927e-01\n", - "4.060e-01 -> 4.060e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 4.900e-02\n", - "SUR: UPD: it: 500 | loss: 5.751e-02\n", - "SUR: UPD: it: 1000 | loss: 4.214e-02\n", - "SUR: UPD: it: 1500 | loss: 4.952e-02\n", - "SUR: UPD: it: 2000 | loss: 3.151e-02\n", - "SUR: UPD: it: 2500 | loss: 2.883e-02\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SUR: UPD: it: 3000 | loss: 2.828e-02\n", - "SUR: UPD: it: 3500 | loss: 2.795e-02\n", - "SUR: UPD: it: 4000 | loss: 2.777e-02\n", - "SUR: UPD: it: 4500 | loss: 2.767e-02\n", - "SUR: UPD: it: 5000 | loss: 2.762e-02\n", - "SUR: UPD: it: 5500 | loss: 2.758e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 5000 | loss: 6.128e+02\n", - "VI NF (t=1.000): it: 5100 | loss: 3.791e+01\n", - "--- Saving results at iteration 5200\n", - "VI NF (t=1.000): it: 5200 | loss: 2.214e+01\n", - "VI NF (t=1.000): it: 5300 | loss: 2.036e+01\n", - "--- Saving results at iteration 5400\n", - "VI NF (t=1.000): it: 5400 | loss: 1.991e+01\n", - "VI NF (t=1.000): it: 5500 | loss: 1.628e+01\n", - "--- Saving results at iteration 5600\n", - "VI NF (t=1.000): it: 5600 | loss: 3.375e+01\n", - "VI NF (t=1.000): it: 5700 | loss: 1.804e+01\n", - "--- Saving results at iteration 5800\n", - "VI NF (t=1.000): it: 5800 | loss: 1.616e+01\n", - "VI NF (t=1.000): it: 5900 | loss: 1.888e+01\n", - "--- Saving results at iteration 6000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.783e-02 -> 1.921e-01\n", - "2.631e-02 -> 1.052e-01\n", - "1.994e-02 -> 2.319e-02\n", - "\n", - "SUR: UPD: it: 0 | loss: 4.027e-02\n", - "SUR: UPD: it: 500 | loss: 4.414e-02\n", - "SUR: UPD: it: 1000 | loss: 2.548e-02\n", - "SUR: UPD: it: 1500 | loss: 2.246e-02\n", - "SUR: UPD: it: 2000 | loss: 2.021e-02\n", - "SUR: UPD: it: 2500 | loss: 1.934e-02\n", - "SUR: UPD: it: 3000 | loss: 1.814e-02\n", - "SUR: UPD: it: 3500 | loss: 1.764e-02\n", - "SUR: UPD: it: 4000 | loss: 1.713e-02\n", - "SUR: UPD: it: 4500 | loss: 1.670e-02\n", - "SUR: UPD: it: 5000 | loss: 1.645e-02\n", - "SUR: UPD: it: 5500 | loss: 1.628e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 6000 | loss: 2.636e+02\n", - "VI NF (t=1.000): it: 6100 | loss: 7.582e+01\n", - "--- Saving results at iteration 6200\n", - "VI NF (t=1.000): it: 6200 | loss: 6.861e+01\n", - "VI NF (t=1.000): it: 6300 | loss: 6.603e+01\n", - "--- Saving results at iteration 6400\n", - "VI NF (t=1.000): it: 6400 | loss: 6.551e+01\n", - "VI NF (t=1.000): it: 6500 | loss: 6.558e+01\n", - "--- Saving results at iteration 6600\n", - "VI NF (t=1.000): it: 6600 | loss: 6.569e+01\n", - "VI NF (t=1.000): it: 6700 | loss: 6.811e+01\n", - "--- Saving results at iteration 6800\n", - "VI NF (t=1.000): it: 6800 | loss: 6.615e+01\n", - "VI NF (t=1.000): it: 6900 | loss: 6.424e+01\n", - "--- Saving results at iteration 7000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.875e-02 -> 4.979e-02\n", - "3.267e-02 -> 1.424e-01\n", - "1.552e-01 -> 1.552e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.089e-02\n", - "SUR: UPD: it: 500 | loss: 2.713e-02\n", - "SUR: UPD: it: 1000 | loss: 2.313e-02\n", - "SUR: UPD: it: 1500 | loss: 1.882e-02\n", - "SUR: UPD: it: 2000 | loss: 1.801e-02\n", - "SUR: UPD: it: 2500 | loss: 1.742e-02\n", - "SUR: UPD: it: 3000 | loss: 1.716e-02\n", - "SUR: UPD: it: 3500 | loss: 1.710e-02\n", - "SUR: UPD: it: 4000 | loss: 1.705e-02\n", - "SUR: UPD: it: 4500 | loss: 1.702e-02\n", - "SUR: UPD: it: 5000 | loss: 1.701e-02\n", - "SUR: UPD: it: 5500 | loss: 1.699e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 7000 | loss: 1.737e+02\n", - "VI NF (t=1.000): it: 7100 | loss: 3.313e+01\n", - "--- Saving results at iteration 7200\n", - "VI NF (t=1.000): it: 7200 | loss: 1.676e+01\n", - "VI NF (t=1.000): it: 7300 | loss: 1.609e+01\n", - "--- Saving results at iteration 7400\n", - "VI NF (t=1.000): it: 7400 | loss: 1.540e+01\n", - "VI NF (t=1.000): it: 7500 | loss: 1.544e+01\n", - "--- Saving results at iteration 7600\n", - "VI NF (t=1.000): it: 7600 | loss: 1.538e+01\n", - "VI NF (t=1.000): it: 7700 | loss: 1.544e+01\n", - "--- Saving results at iteration 7800\n", - "VI NF (t=1.000): it: 7800 | loss: 1.539e+01\n", - "VI NF (t=1.000): it: 7900 | loss: 1.546e+01\n", - "--- Saving results at iteration 8000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.039e-02 -> 7.366e-04\n", - "4.292e-02 -> 8.804e-02\n", - "1.017e-01 -> 1.017e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.772e-02\n", - "SUR: UPD: it: 500 | loss: 3.406e-02\n", - "SUR: UPD: it: 1000 | loss: 2.162e-02\n", - "SUR: UPD: it: 1500 | loss: 1.723e-02\n", - "SUR: UPD: it: 2000 | loss: 1.722e-02\n", - "SUR: UPD: it: 2500 | loss: 1.477e-02\n", - "SUR: UPD: it: 3000 | loss: 1.461e-02\n", - "SUR: UPD: it: 3500 | loss: 1.451e-02\n", - "SUR: UPD: it: 4000 | loss: 1.447e-02\n", - "SUR: UPD: it: 4500 | loss: 1.444e-02\n", - "SUR: UPD: it: 5000 | loss: 1.442e-02\n", - "SUR: UPD: it: 5500 | loss: 1.440e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 8000 | loss: 3.326e+02\n", - "VI NF (t=1.000): it: 8100 | loss: 1.369e+01\n", - "--- Saving results at iteration 8200\n", - "VI NF (t=1.000): it: 8200 | loss: 1.329e+01\n", - "VI NF (t=1.000): it: 8300 | loss: 1.301e+01\n", - "--- Saving results at iteration 8400\n", - "VI NF (t=1.000): it: 8400 | loss: 1.269e+01\n", - "VI NF (t=1.000): it: 8500 | loss: 1.264e+01\n", - "--- Saving results at iteration 8600\n", - "VI NF (t=1.000): it: 8600 | loss: 1.261e+01\n", - "VI NF (t=1.000): it: 8700 | loss: 1.229e+01\n", - "--- Saving results at iteration 8800\n", - "VI NF (t=1.000): it: 8800 | loss: 1.231e+01\n", - "VI NF (t=1.000): it: 8900 | loss: 1.232e+01\n", - "--- Saving results at iteration 9000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.846e-02 -> 6.436e-02\n", - "1.730e-01 -> 1.730e-01\n", - "3.521e-01 -> 3.521e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.272e-02\n", - "SUR: UPD: it: 500 | loss: 3.300e-02\n", - "SUR: UPD: it: 1000 | loss: 2.415e-02\n", - "SUR: UPD: it: 1500 | loss: 1.757e-02\n", - "SUR: UPD: it: 2000 | loss: 1.608e-02\n", - "SUR: UPD: it: 2500 | loss: 1.586e-02\n", - "SUR: UPD: it: 3000 | loss: 1.494e-02\n", - "SUR: UPD: it: 3500 | loss: 1.483e-02\n", - "SUR: UPD: it: 4000 | loss: 1.472e-02\n", - "SUR: UPD: it: 4500 | loss: 1.464e-02\n", - "SUR: UPD: it: 5000 | loss: 1.457e-02\n", - "SUR: UPD: it: 5500 | loss: 1.452e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 9000 | loss: 7.471e+01\n", - "VI NF (t=1.000): it: 9100 | loss: 1.229e+01\n", - "--- Saving results at iteration 9200\n", - "VI NF (t=1.000): it: 9200 | loss: 1.247e+01\n", - "VI NF (t=1.000): it: 9300 | loss: 1.216e+01\n", - "--- Saving results at iteration 9400\n", - "VI NF (t=1.000): it: 9400 | loss: 1.212e+01\n", - "VI NF (t=1.000): it: 9500 | loss: 1.204e+01\n", - "--- Saving results at iteration 9600\n", - "VI NF (t=1.000): it: 9600 | loss: 1.196e+01\n", - "VI NF (t=1.000): it: 9700 | loss: 1.207e+01\n", - "--- Saving results at iteration 9800\n", - "VI NF (t=1.000): it: 9800 | loss: 1.209e+01\n", - "VI NF (t=1.000): it: 9900 | loss: 1.214e+01\n", - "--- Saving results at iteration 10000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.490e-02 -> 1.518e-01\n", - "7.003e-02 -> 1.996e-01\n", - "1.032e-01 -> 1.032e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.231e-02\n", - "SUR: UPD: it: 500 | loss: 4.040e-02\n", - "SUR: UPD: it: 1000 | loss: 2.491e-02\n", - "SUR: UPD: it: 1500 | loss: 2.291e-02\n", - "SUR: UPD: it: 2000 | loss: 1.977e-02\n", - "SUR: UPD: it: 2500 | loss: 1.854e-02\n", - "SUR: UPD: it: 3000 | loss: 1.834e-02\n", - "SUR: UPD: it: 3500 | loss: 1.828e-02\n", - "SUR: UPD: it: 4000 | loss: 1.826e-02\n", - "SUR: UPD: it: 4500 | loss: 1.822e-02\n", - "SUR: UPD: it: 5000 | loss: 1.802e-02\n", - "SUR: UPD: it: 5500 | loss: 1.757e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 10000 | loss: 6.577e+01\n", - "VI NF (t=1.000): it: 10100 | loss: 8.531e+00\n", - "--- Saving results at iteration 10200\n", - "VI NF (t=1.000): it: 10200 | loss: 8.541e+00\n", - "VI NF (t=1.000): it: 10300 | loss: 8.465e+00\n", - "--- Saving results at iteration 10400\n", - "VI NF (t=1.000): it: 10400 | loss: 8.501e+00\n", - "VI NF (t=1.000): it: 10500 | loss: 8.413e+00\n", - "--- Saving results at iteration 10600\n", - "VI NF (t=1.000): it: 10600 | loss: 8.319e+00\n", - "VI NF (t=1.000): it: 10700 | loss: 8.323e+00\n", - "--- Saving results at iteration 10800\n", - "VI NF (t=1.000): it: 10800 | loss: 8.116e+00\n", - "VI NF (t=1.000): it: 10900 | loss: 8.229e+00\n", - "--- Saving results at iteration 11000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.900e-02 -> 1.633e-01\n", - "3.385e-02 -> 2.657e-02\n", - "1.064e-01 -> 1.064e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.816e-02\n", - "SUR: UPD: it: 500 | loss: 6.173e-02\n", - "SUR: UPD: it: 1000 | loss: 2.283e-02\n", - "SUR: UPD: it: 1500 | loss: 2.227e-02\n", - "SUR: UPD: it: 2000 | loss: 1.935e-02\n", - "SUR: UPD: it: 2500 | loss: 1.887e-02\n", - "SUR: UPD: it: 3000 | loss: 1.812e-02\n", - "SUR: UPD: it: 3500 | loss: 1.787e-02\n", - "SUR: UPD: it: 4000 | loss: 1.775e-02\n", - "SUR: UPD: it: 4500 | loss: 1.767e-02\n", - "SUR: UPD: it: 5000 | loss: 1.740e-02\n", - "SUR: UPD: it: 5500 | loss: 1.731e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 11000 | loss: 3.733e+01\n", - "VI NF (t=1.000): it: 11100 | loss: 8.160e+00\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--- Saving results at iteration 11200\n", - "VI NF (t=1.000): it: 11200 | loss: 8.262e+00\n", - "VI NF (t=1.000): it: 11300 | loss: 8.345e+00\n", - "--- Saving results at iteration 11400\n", - "VI NF (t=1.000): it: 11400 | loss: 8.299e+00\n", - "VI NF (t=1.000): it: 11500 | loss: 8.185e+00\n", - "--- Saving results at iteration 11600\n", - "VI NF (t=1.000): it: 11600 | loss: 8.205e+00\n", - "VI NF (t=1.000): it: 11700 | loss: 8.006e+00\n", - "--- Saving results at iteration 11800\n", - "VI NF (t=1.000): it: 11800 | loss: 8.015e+00\n", - "VI NF (t=1.000): it: 11900 | loss: 7.950e+00\n", - "--- Saving results at iteration 12000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "5.648e-02 -> 8.101e-02\n", - "5.597e-02 -> 4.461e-02\n", - "1.413e-01 -> 1.413e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.698e-02\n", - "SUR: UPD: it: 500 | loss: 2.345e-02\n", - "SUR: UPD: it: 1000 | loss: 2.031e-02\n", - "SUR: UPD: it: 1500 | loss: 1.864e-02\n", - "SUR: UPD: it: 2000 | loss: 1.826e-02\n", - "SUR: UPD: it: 2500 | loss: 1.869e-02\n", - "SUR: UPD: it: 3000 | loss: 1.805e-02\n", - "SUR: UPD: it: 3500 | loss: 1.784e-02\n", - "SUR: UPD: it: 4000 | loss: 1.776e-02\n", - "SUR: UPD: it: 4500 | loss: 1.773e-02\n", - "SUR: UPD: it: 5000 | loss: 1.769e-02\n", - "SUR: UPD: it: 5500 | loss: 1.768e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 12000 | loss: 1.354e+01\n", - "VI NF (t=1.000): it: 12100 | loss: 8.426e+00\n", - "--- Saving results at iteration 12200\n", - "VI NF (t=1.000): it: 12200 | loss: 8.481e+00\n", - "VI NF (t=1.000): it: 12300 | loss: 8.376e+00\n", - "--- Saving results at iteration 12400\n", - "VI NF (t=1.000): it: 12400 | loss: 8.452e+00\n", - "VI NF (t=1.000): it: 12500 | loss: 8.355e+00\n", - "--- Saving results at iteration 12600\n", - "VI NF (t=1.000): it: 12600 | loss: 8.457e+00\n", - "VI NF (t=1.000): it: 12700 | loss: 8.435e+00\n", - "--- Saving results at iteration 12800\n", - "VI NF (t=1.000): it: 12800 | loss: 8.424e+00\n", - "VI NF (t=1.000): it: 12900 | loss: 8.247e+00\n", - "--- Saving results at iteration 13000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.207e-01 -> 2.207e-01\n", - "2.592e-01 -> 2.592e-01\n", - "4.852e-01 -> 4.852e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 5.004e-02\n", - "SUR: UPD: it: 500 | loss: 5.291e-02\n", - "SUR: UPD: it: 1000 | loss: 3.257e-02\n", - "SUR: UPD: it: 1500 | loss: 2.680e-02\n", - "SUR: UPD: it: 2000 | loss: 2.533e-02\n", - "SUR: UPD: it: 2500 | loss: 2.314e-02\n", - "SUR: UPD: it: 3000 | loss: 2.243e-02\n", - "SUR: UPD: it: 3500 | loss: 2.152e-02\n", - "SUR: UPD: it: 4000 | loss: 2.126e-02\n", - "SUR: UPD: it: 4500 | loss: 2.067e-02\n", - "SUR: UPD: it: 5000 | loss: 2.036e-02\n", - "SUR: UPD: it: 5500 | loss: 2.020e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 13000 | loss: 9.114e+01\n", - "VI NF (t=1.000): it: 13100 | loss: 2.167e+01\n", - "--- Saving results at iteration 13200\n", - "VI NF (t=1.000): it: 13200 | loss: 1.573e+01\n", - "VI NF (t=1.000): it: 13300 | loss: 1.178e+01\n", - "--- Saving results at iteration 13400\n", - "VI NF (t=1.000): it: 13400 | loss: 1.247e+01\n", - "VI NF (t=1.000): it: 13500 | loss: 1.090e+01\n", - "--- Saving results at iteration 13600\n", - "VI NF (t=1.000): it: 13600 | loss: 9.036e+00\n", - "VI NF (t=1.000): it: 13700 | loss: 8.584e+00\n", - "--- Saving results at iteration 13800\n", - "VI NF (t=1.000): it: 13800 | loss: 8.311e+00\n", - "VI NF (t=1.000): it: 13900 | loss: 8.694e+00\n", - "--- Saving results at iteration 14000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "1.368e-01 -> 1.368e-01\n", - "3.581e-02 -> 3.616e-02\n", - "1.221e-01 -> 1.221e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.010e-02\n", - "SUR: UPD: it: 500 | loss: 6.309e-02\n", - "SUR: UPD: it: 1000 | loss: 2.895e-02\n", - "SUR: UPD: it: 1500 | loss: 2.875e-02\n", - "SUR: UPD: it: 2000 | loss: 2.347e-02\n", - "SUR: UPD: it: 2500 | loss: 2.256e-02\n", - "SUR: UPD: it: 3000 | loss: 2.229e-02\n", - "SUR: UPD: it: 3500 | loss: 2.201e-02\n", - "SUR: UPD: it: 4000 | loss: 2.186e-02\n", - "SUR: UPD: it: 4500 | loss: 2.179e-02\n", - "SUR: UPD: it: 5000 | loss: 2.174e-02\n", - "SUR: UPD: it: 5500 | loss: 2.170e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 14000 | loss: 1.318e+02\n", - "VI NF (t=1.000): it: 14100 | loss: 8.762e+00\n", - "--- Saving results at iteration 14200\n", - "VI NF (t=1.000): it: 14200 | loss: 8.849e+00\n", - "VI NF (t=1.000): it: 14300 | loss: 8.601e+00\n", - "--- Saving results at iteration 14400\n", - "VI NF (t=1.000): it: 14400 | loss: 8.283e+00\n", - "VI NF (t=1.000): it: 14500 | loss: 8.712e+00\n", - "--- Saving results at iteration 14600\n", - "VI NF (t=1.000): it: 14600 | loss: 8.355e+00\n", - "VI NF (t=1.000): it: 14700 | loss: 8.287e+00\n", - "--- Saving results at iteration 14800\n", - "VI NF (t=1.000): it: 14800 | loss: 8.293e+00\n", - "VI NF (t=1.000): it: 14900 | loss: 8.351e+00\n", - "--- Saving results at iteration 15000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "6.254e-03 -> 3.839e-02\n", - "9.847e-03 -> 4.599e-02\n", - "4.278e-02 -> 1.183e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.331e-02\n", - "SUR: UPD: it: 500 | loss: 3.625e-02\n", - "SUR: UPD: it: 1000 | loss: 2.875e-02\n", - "SUR: UPD: it: 1500 | loss: 2.357e-02\n", - "SUR: UPD: it: 2000 | loss: 2.148e-02\n", - "SUR: UPD: it: 2500 | loss: 2.063e-02\n", - "SUR: UPD: it: 3000 | loss: 2.010e-02\n", - "SUR: UPD: it: 3500 | loss: 1.995e-02\n", - "SUR: UPD: it: 4000 | loss: 1.983e-02\n", - "SUR: UPD: it: 4500 | loss: 1.970e-02\n", - "SUR: UPD: it: 5000 | loss: 1.959e-02\n", - "SUR: UPD: it: 5500 | loss: 1.949e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 15000 | loss: 1.412e+01\n", - "VI NF (t=1.000): it: 15100 | loss: 8.362e+00\n", - "--- Saving results at iteration 15200\n", - "VI NF (t=1.000): it: 15200 | loss: 8.607e+00\n", - "VI NF (t=1.000): it: 15300 | loss: 8.312e+00\n", - "--- Saving results at iteration 15400\n", - "VI NF (t=1.000): it: 15400 | loss: 8.334e+00\n", - "VI NF (t=1.000): it: 15500 | loss: 8.247e+00\n", - "--- Saving results at iteration 15600\n", - "VI NF (t=1.000): it: 15600 | loss: 8.416e+00\n", - "VI NF (t=1.000): it: 15700 | loss: 1.009e+01\n", - "--- Saving results at iteration 15800\n", - "VI NF (t=1.000): it: 15800 | loss: 8.227e+00\n", - "VI NF (t=1.000): it: 15900 | loss: 8.265e+00\n", - "--- Saving results at iteration 16000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "3.762e-01 -> 3.762e-01\n", - "2.184e-01 -> 2.184e-01\n", - "5.993e-01 -> 5.993e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.844e-02\n", - "SUR: UPD: it: 500 | loss: 6.388e-02\n", - "SUR: UPD: it: 1000 | loss: 4.104e-02\n", - "SUR: UPD: it: 1500 | loss: 3.234e-02\n", - "SUR: UPD: it: 2000 | loss: 2.950e-02\n", - "SUR: UPD: it: 2500 | loss: 2.727e-02\n", - "SUR: UPD: it: 3000 | loss: 2.658e-02\n", - "SUR: UPD: it: 3500 | loss: 2.597e-02\n", - "SUR: UPD: it: 4000 | loss: 2.568e-02\n", - "SUR: UPD: it: 4500 | loss: 2.542e-02\n", - "SUR: UPD: it: 5000 | loss: 2.518e-02\n", - "SUR: UPD: it: 5500 | loss: 2.500e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 16000 | loss: 1.315e+01\n", - "VI NF (t=1.000): it: 16100 | loss: 7.895e+00\n", - "--- Saving results at iteration 16200\n", - "VI NF (t=1.000): it: 16200 | loss: 7.920e+00\n", - "VI NF (t=1.000): it: 16300 | loss: 7.633e+00\n", - "--- Saving results at iteration 16400\n", - "VI NF (t=1.000): it: 16400 | loss: 7.941e+00\n", - "VI NF (t=1.000): it: 16500 | loss: 7.943e+00\n", - "--- Saving results at iteration 16600\n", - "VI NF (t=1.000): it: 16600 | loss: 7.543e+00\n", - "VI NF (t=1.000): it: 16700 | loss: 7.591e+00\n", - "--- Saving results at iteration 16800\n", - "VI NF (t=1.000): it: 16800 | loss: 7.587e+00\n", - "VI NF (t=1.000): it: 16900 | loss: 7.716e+00\n", - "--- Saving results at iteration 17000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.002e-01 -> 2.002e-01\n", - "3.297e-01 -> 3.297e-01\n", - "7.690e-01 -> 7.690e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.912e-02\n", - "SUR: UPD: it: 500 | loss: 4.944e-02\n", - "SUR: UPD: it: 1000 | loss: 3.491e-02\n", - "SUR: UPD: it: 1500 | loss: 2.893e-02\n", - "SUR: UPD: it: 2000 | loss: 2.663e-02\n", - "SUR: UPD: it: 2500 | loss: 2.504e-02\n", - "SUR: UPD: it: 3000 | loss: 2.351e-02\n", - "SUR: UPD: it: 3500 | loss: 2.124e-02\n", - "SUR: UPD: it: 4000 | loss: 2.047e-02\n", - "SUR: UPD: it: 4500 | loss: 2.019e-02\n", - "SUR: UPD: it: 5000 | loss: 2.004e-02\n", - "SUR: UPD: it: 5500 | loss: 1.995e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 17000 | loss: 2.869e+01\n", - "VI NF (t=1.000): it: 17100 | loss: 8.180e+00\n", - "--- Saving results at iteration 17200\n", - "VI NF (t=1.000): it: 17200 | loss: 8.140e+00\n", - "VI NF (t=1.000): it: 17300 | loss: 8.126e+00\n", - "--- Saving results at iteration 17400\n", - "VI NF (t=1.000): it: 17400 | loss: 8.364e+00\n", - "VI NF (t=1.000): it: 17500 | loss: 8.034e+00\n", - "--- Saving results at iteration 17600\n", - "VI NF (t=1.000): it: 17600 | loss: 7.958e+00\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "VI NF (t=1.000): it: 17700 | loss: 8.066e+00\n", - "--- Saving results at iteration 17800\n", - "VI NF (t=1.000): it: 17800 | loss: 7.991e+00\n", - "VI NF (t=1.000): it: 17900 | loss: 8.044e+00\n", - "--- Saving results at iteration 18000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "8.472e-02 -> 1.976e-01\n", - "2.813e-02 -> 1.388e-01\n", - "1.240e-01 -> 1.240e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.982e-02\n", - "SUR: UPD: it: 500 | loss: 3.887e-02\n", - "SUR: UPD: it: 1000 | loss: 3.254e-02\n", - "SUR: UPD: it: 1500 | loss: 2.662e-02\n", - "SUR: UPD: it: 2000 | loss: 2.565e-02\n", - "SUR: UPD: it: 2500 | loss: 2.448e-02\n", - "SUR: UPD: it: 3000 | loss: 2.400e-02\n", - "SUR: UPD: it: 3500 | loss: 2.394e-02\n", - "SUR: UPD: it: 4000 | loss: 2.391e-02\n", - "SUR: UPD: it: 4500 | loss: 2.388e-02\n", - "SUR: UPD: it: 5000 | loss: 2.386e-02\n", - "SUR: UPD: it: 5500 | loss: 2.386e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 18000 | loss: 8.273e+00\n", - "VI NF (t=1.000): it: 18100 | loss: 7.797e+00\n", - "--- Saving results at iteration 18200\n", - "VI NF (t=1.000): it: 18200 | loss: 8.006e+00\n", - "VI NF (t=1.000): it: 18300 | loss: 7.764e+00\n", - "--- Saving results at iteration 18400\n", - "VI NF (t=1.000): it: 18400 | loss: 7.838e+00\n", - "VI NF (t=1.000): it: 18500 | loss: 7.629e+00\n", - "--- Saving results at iteration 18600\n", - "VI NF (t=1.000): it: 18600 | loss: 7.693e+00\n", - "VI NF (t=1.000): it: 18700 | loss: 7.658e+00\n", - "--- Saving results at iteration 18800\n", - "VI NF (t=1.000): it: 18800 | loss: 7.780e+00\n", - "VI NF (t=1.000): it: 18900 | loss: 7.740e+00\n", - "--- Saving results at iteration 19000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "8.146e-02 -> 9.206e-03\n", - "3.396e-02 -> 1.539e-01\n", - "1.390e-01 -> 1.390e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.359e-02\n", - "SUR: UPD: it: 500 | loss: 4.769e-02\n", - "SUR: UPD: it: 1000 | loss: 3.693e-02\n", - "SUR: UPD: it: 1500 | loss: 2.692e-02\n", - "SUR: UPD: it: 2000 | loss: 2.481e-02\n", - "SUR: UPD: it: 2500 | loss: 2.386e-02\n", - "SUR: UPD: it: 3000 | loss: 2.373e-02\n", - "SUR: UPD: it: 3500 | loss: 2.364e-02\n", - "SUR: UPD: it: 4000 | loss: 2.359e-02\n", - "SUR: UPD: it: 4500 | loss: 2.358e-02\n", - "SUR: UPD: it: 5000 | loss: 2.358e-02\n", - "SUR: UPD: it: 5500 | loss: 2.357e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 19000 | loss: 7.600e+00\n", - "VI NF (t=1.000): it: 19100 | loss: 7.881e+00\n", - "--- Saving results at iteration 19200\n", - "VI NF (t=1.000): it: 19200 | loss: 7.555e+00\n", - "VI NF (t=1.000): it: 19300 | loss: 7.515e+00\n", - "--- Saving results at iteration 19400\n", - "VI NF (t=1.000): it: 19400 | loss: 7.828e+00\n", - "VI NF (t=1.000): it: 19500 | loss: 7.547e+00\n", - "--- Saving results at iteration 19600\n", - "VI NF (t=1.000): it: 19600 | loss: 7.576e+00\n", - "VI NF (t=1.000): it: 19700 | loss: 7.657e+00\n", - "--- Saving results at iteration 19800\n", - "VI NF (t=1.000): it: 19800 | loss: 7.776e+00\n", - "VI NF (t=1.000): it: 19900 | loss: 7.884e+00\n", - "--- Saving results at iteration 20000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "9.719e-01 -> 9.719e-01\n", - "5.516e-01 -> 5.516e-01\n", - "1.492e+00 -> 1.492e+00\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.444e-02\n", - "SUR: UPD: it: 500 | loss: 4.757e-02\n", - "SUR: UPD: it: 1000 | loss: 3.443e-02\n", - "SUR: UPD: it: 1500 | loss: 2.587e-02\n", - "SUR: UPD: it: 2000 | loss: 2.446e-02\n", - "SUR: UPD: it: 2500 | loss: 2.361e-02\n", - "SUR: UPD: it: 3000 | loss: 2.309e-02\n", - "SUR: UPD: it: 3500 | loss: 2.296e-02\n", - "SUR: UPD: it: 4000 | loss: 2.290e-02\n", - "SUR: UPD: it: 4500 | loss: 2.286e-02\n", - "SUR: UPD: it: 5000 | loss: 2.283e-02\n", - "SUR: UPD: it: 5500 | loss: 2.281e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 20000 | loss: 1.059e+01\n", - "VI NF (t=1.000): it: 20100 | loss: 7.395e+00\n", - "--- Saving results at iteration 20200\n", - "VI NF (t=1.000): it: 20200 | loss: 7.442e+00\n", - "VI NF (t=1.000): it: 20300 | loss: 7.673e+00\n", - "--- Saving results at iteration 20400\n", - "VI NF (t=1.000): it: 20400 | loss: 7.371e+00\n", - "VI NF (t=1.000): it: 20500 | loss: 7.440e+00\n", - "--- Saving results at iteration 20600\n", - "VI NF (t=1.000): it: 20600 | loss: 7.472e+00\n", - "VI NF (t=1.000): it: 20700 | loss: 7.569e+00\n", - "--- Saving results at iteration 20800\n", - "VI NF (t=1.000): it: 20800 | loss: 7.531e+00\n", - "VI NF (t=1.000): it: 20900 | loss: 7.463e+00\n", - "--- Saving results at iteration 21000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "2.064e-01 -> 2.064e-01\n", - "2.456e-01 -> 2.456e-01\n", - "5.328e-01 -> 5.328e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.463e-02\n", - "SUR: UPD: it: 500 | loss: 4.224e-02\n", - "SUR: UPD: it: 1000 | loss: 1.952e-02\n", - "SUR: UPD: it: 1500 | loss: 1.620e-02\n", - "SUR: UPD: it: 2000 | loss: 1.566e-02\n", - "SUR: UPD: it: 2500 | loss: 1.457e-02\n", - "SUR: UPD: it: 3000 | loss: 1.404e-02\n", - "SUR: UPD: it: 3500 | loss: 1.375e-02\n", - "SUR: UPD: it: 4000 | loss: 1.363e-02\n", - "SUR: UPD: it: 4500 | loss: 1.357e-02\n", - "SUR: UPD: it: 5000 | loss: 1.350e-02\n", - "SUR: UPD: it: 5500 | loss: 1.344e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 21000 | loss: 1.067e+01\n", - "VI NF (t=1.000): it: 21100 | loss: 8.688e+00\n", - "--- Saving results at iteration 21200\n", - "VI NF (t=1.000): it: 21200 | loss: 8.499e+00\n", - "VI NF (t=1.000): it: 21300 | loss: 8.798e+00\n", - "--- Saving results at iteration 21400\n", - "VI NF (t=1.000): it: 21400 | loss: 8.715e+00\n", - "VI NF (t=1.000): it: 21500 | loss: 8.587e+00\n", - "--- Saving results at iteration 21600\n", - "VI NF (t=1.000): it: 21600 | loss: 8.636e+00\n", - "VI NF (t=1.000): it: 21700 | loss: 8.520e+00\n", - "--- Saving results at iteration 21800\n", - "VI NF (t=1.000): it: 21800 | loss: 8.613e+00\n", - "VI NF (t=1.000): it: 21900 | loss: 8.593e+00\n", - "--- Saving results at iteration 22000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "4.674e-02 -> 4.935e-02\n", - "1.328e-02 -> 2.843e-02\n", - "8.518e-02 -> 1.284e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.253e-02\n", - "SUR: UPD: it: 500 | loss: 4.190e-02\n", - "SUR: UPD: it: 1000 | loss: 3.176e-02\n", - "SUR: UPD: it: 1500 | loss: 2.807e-02\n", - "SUR: UPD: it: 2000 | loss: 2.382e-02\n", - "SUR: UPD: it: 2500 | loss: 2.205e-02\n", - "SUR: UPD: it: 3000 | loss: 2.200e-02\n", - "SUR: UPD: it: 3500 | loss: 2.194e-02\n", - "SUR: UPD: it: 4000 | loss: 2.187e-02\n", - "SUR: UPD: it: 4500 | loss: 2.184e-02\n", - "SUR: UPD: it: 5000 | loss: 2.182e-02\n", - "SUR: UPD: it: 5500 | loss: 2.177e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 22000 | loss: 7.514e+00\n", - "VI NF (t=1.000): it: 22100 | loss: 7.504e+00\n", - "--- Saving results at iteration 22200\n", - "VI NF (t=1.000): it: 22200 | loss: 7.240e+00\n", - "VI NF (t=1.000): it: 22300 | loss: 7.377e+00\n", - "--- Saving results at iteration 22400\n", - "VI NF (t=1.000): it: 22400 | loss: 7.292e+00\n", - "VI NF (t=1.000): it: 22500 | loss: 7.325e+00\n", - "--- Saving results at iteration 22600\n", - "VI NF (t=1.000): it: 22600 | loss: 7.229e+00\n", - "VI NF (t=1.000): it: 22700 | loss: 7.396e+00\n", - "--- Saving results at iteration 22800\n", - "VI NF (t=1.000): it: 22800 | loss: 7.363e+00\n", - "VI NF (t=1.000): it: 22900 | loss: 7.319e+00\n", - "--- Saving results at iteration 23000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "1.420e-01 -> 1.420e-01\n", - "9.494e-02 -> 1.663e-01\n", - "2.319e-01 -> 2.319e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.139e-02\n", - "SUR: UPD: it: 500 | loss: 4.424e-02\n", - "SUR: UPD: it: 1000 | loss: 3.247e-02\n", - "SUR: UPD: it: 1500 | loss: 2.408e-02\n", - "SUR: UPD: it: 2000 | loss: 2.311e-02\n", - "SUR: UPD: it: 2500 | loss: 2.188e-02\n", - "SUR: UPD: it: 3000 | loss: 2.148e-02\n", - "SUR: UPD: it: 3500 | loss: 2.146e-02\n", - "SUR: UPD: it: 4000 | loss: 2.133e-02\n", - "SUR: UPD: it: 4500 | loss: 2.130e-02\n", - "SUR: UPD: it: 5000 | loss: 2.123e-02\n", - "SUR: UPD: it: 5500 | loss: 2.118e-02\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 23000 | loss: 7.743e+00\n", - "VI NF (t=1.000): it: 23100 | loss: 7.326e+00\n", - "--- Saving results at iteration 23200\n", - "VI NF (t=1.000): it: 23200 | loss: 7.543e+00\n", - "VI NF (t=1.000): it: 23300 | loss: 7.416e+00\n", - "--- Saving results at iteration 23400\n", - "VI NF (t=1.000): it: 23400 | loss: 7.091e+00\n", - "VI NF (t=1.000): it: 23500 | loss: 7.395e+00\n", - "--- Saving results at iteration 23600\n", - "VI NF (t=1.000): it: 23600 | loss: 7.419e+00\n", - "VI NF (t=1.000): it: 23700 | loss: 7.228e+00\n", - "--- Saving results at iteration 23800\n", - "VI NF (t=1.000): it: 23800 | loss: 7.377e+00\n", - "VI NF (t=1.000): it: 23900 | loss: 7.400e+00\n", - "--- Saving results at iteration 24000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "7.360e-01 -> 7.360e-01\n", - "3.871e-01 -> 3.871e-01\n", - "9.115e-01 -> 9.115e-01\n", - "\n", - "SUR: UPD: it: 0 | loss: 2.025e-02\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SUR: UPD: it: 500 | loss: 1.830e-01\n", - "SUR: UPD: it: 1000 | loss: 1.345e-01\n", - "SUR: UPD: it: 1500 | loss: 1.244e-01\n", - "SUR: UPD: it: 2000 | loss: 1.238e-01\n", - "SUR: UPD: it: 2500 | loss: 1.230e-01\n", - "SUR: UPD: it: 3000 | loss: 1.224e-01\n", - "SUR: UPD: it: 3500 | loss: 1.222e-01\n", - "SUR: UPD: it: 4000 | loss: 1.221e-01\n", - "SUR: UPD: it: 4500 | loss: 1.221e-01\n", - "SUR: UPD: it: 5000 | loss: 1.220e-01\n", - "SUR: UPD: it: 5500 | loss: 1.220e-01\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 24000 | loss: 2.198e+01\n", - "VI NF (t=1.000): it: 24100 | loss: 8.043e+00\n", - "--- Saving results at iteration 24200\n", - "VI NF (t=1.000): it: 24200 | loss: 7.933e+00\n", - "VI NF (t=1.000): it: 24300 | loss: 7.971e+00\n", - "--- Saving results at iteration 24400\n", - "VI NF (t=1.000): it: 24400 | loss: 7.971e+00\n", - "VI NF (t=1.000): it: 24500 | loss: 7.792e+00\n", - "--- Saving results at iteration 24600\n", - "VI NF (t=1.000): it: 24600 | loss: 7.757e+00\n", - "VI NF (t=1.000): it: 24700 | loss: 7.741e+00\n", - "--- Saving results at iteration 24800\n", - "VI NF (t=1.000): it: 24800 | loss: 7.848e+00\n", - "VI NF (t=1.000): it: 24900 | loss: 7.997e+00\n", - "--- Saving results at iteration 25000\n", - "\n", - "--- Updating surrogate model\n", - "\n", - "Std before inflation -> Std after inflation\n", - "7.188e-01 -> 7.188e-01\n", - "3.309e-01 -> 3.309e-01\n", - "1.044e+00 -> 1.044e+00\n", - "\n", - "SUR: UPD: it: 0 | loss: 1.208e-01\n", - "SUR: UPD: it: 500 | loss: 1.643e-01\n", - "SUR: UPD: it: 1000 | loss: 1.268e-01\n", - "SUR: UPD: it: 1500 | loss: 1.236e-01\n", - "SUR: UPD: it: 2000 | loss: 1.217e-01\n", - "SUR: UPD: it: 2500 | loss: 1.215e-01\n", - "SUR: UPD: it: 3000 | loss: 1.206e-01\n", - "SUR: UPD: it: 3500 | loss: 1.206e-01\n", - "SUR: UPD: it: 4000 | loss: 1.206e-01\n", - "SUR: UPD: it: 4500 | loss: 1.205e-01\n", - "SUR: UPD: it: 5000 | loss: 1.205e-01\n", - "SUR: UPD: it: 5500 | loss: 1.205e-01\n", - "\n", - "--- Surrogate model updated\n", - "\n", - "VI NF (t=1.000): it: 25000 | loss: 8.470e+00\n", - "\n", - "--- Simulation completed!\n" - ] - } - ], - "source": [ - "print('')\n", - "print('--- Temporary TEST: Physics Example - NOFAS')\n", - "print('')\n", - "\n", - "# Experiment Setting\n", - "exp = experiment()\n", - "exp.name = \"phys\" # str: Name of experiment\n", - "exp.flow_type = 'maf' # str: Type of flow\n", - "exp.n_blocks = 5 # int: Number of layers \n", - "exp.hidden_size = 100 # int: Hidden layer size for MADE in each layer \n", - "exp.n_hidden = 1 # int: Number of hidden layers in each MADE \n", - "exp.activation_fn = 'relu' # str: Actication function used \n", - "exp.input_order = 'sequential' # str: Input order for create_mask \n", - "exp.batch_norm_order = True # boolean: Order to decide if batch_norm is used \n", - "exp.sampling_interval = 5000 # int: How often to sample from normalizing flow\n", - "\n", - "exp.input_size = 3 # int: Dimensionality of input \n", - "exp.batch_size = 250 # int: Number of samples generated \n", - "exp.true_data_num = 2 # double: number of true model evaluated \n", - "exp.n_iter = 25001 # int: Number of iterations \n", - "exp.lr = 0.01 # float: Learning rate \n", - "exp.lr_decay = 0.9999 # float: Learning rate decay \n", - "exp.log_interval = 100 # int: How often to show loss stat \n", - "\n", - "exp.run_nofas = True\n", - "exp.annealing = False\n", - "exp.calibrate_interval = 1000 # int: How often to update surrogate model default 1000\n", - "exp.budget = 260 # int: Total number of true model evaluation\n", - "exp.surr_pre_it = 20000 #:int: Number of pre-training iterations for surrogate model\n", - "exp.surr_upd_it = 6000 #:int: Number of iterations for the surrogate model update\n", - "exp.surr_folder = \"./\"\n", - "exp.use_new_surr = True\n", - "\n", - "exp.output_dir = './' + exp.name\n", - "exp.results_file = 'results.txt'\n", - "exp.log_file = 'log.txt'\n", - "exp.samples_file = 'samples.txt'\n", - "exp.seed = random.randint(0, 10 ** 9) # int: Random seed used\n", - "exp.n_sample = 5000 # int: Total number of iterations\n", - "exp.no_cuda = True\n", - "\n", - "exp.optimizer = 'RMSprop'\n", - "exp.lr_scheduler = 'ExponentialLR'\n", - "\n", - "exp.device = torch.device('cuda:0' if torch.cuda.is_available() and not exp.no_cuda else 'cpu')\n", - "print('--- Running on device: '+ str(exp.device))\n", - "print('')\n", - "\n", - "# Define transformation\n", - "trsf_info = [['identity',0.0,0.0,0.0,0.0],\n", - " ['identity',0.0,0.0,0.0,0.0],\n", - " ['linear',-3,3,30.0,80.0]]\n", - "trsf = Transformation(trsf_info) \n", - "exp.transform = trsf\n", - "\n", - "# Set model and get data\n", - "exp.model = model\n", - "model.data = np.loadtxt('./data_phys.txt')\n", - "\n", - "# Define surrogate\n", - "exp.surrogate = Surrogate(exp.name, lambda x: model.solve_t(trsf.forward(x)), exp.input_size, 3, \n", - " model_folder=exp.surr_folder, limits=torch.Tensor([[0, 2], [0, 10], [-3, 3]]), \n", - " memory_len=20, device=exp.device)\n", - "surr_filename = exp.surr_folder + exp.name\n", - "if exp.use_new_surr or (not os.path.isfile(surr_filename + \".sur\")) or (not os.path.isfile(surr_filename + \".npz\")):\n", - " print(\"Warning: Surrogate model files: {0}.npz and {0}.npz could not be found. \".format(surr_filename))\n", - " # 4 samples for each dimension: pre-grid size = 16\n", - "# exp.surrogate.gen_grid(gridnum=4)\n", - " exp.surrogate.gen_grid(gridnum=3)\n", - " exp.surrogate.pre_train(exp.surr_pre_it, 0.03, 0.9999, 500, store=True)\n", - "# Load the surrogate\n", - "exp.surrogate.surrogate_load()\n", - "\n", - "output_dir = './' + exp.name\n", - "\n", - "# Define log density\n", - "def log_density(x, model, surrogate, transform):\n", - " # x contains the original, untransformed inputs\n", - " np.savetxt(output_dir + '/' + exp.name + '_x', x.detach().numpy(), newline=\"\\n\")\n", - " # Compute transformation log Jacobian\n", - " adjust = transform.compute_log_jacob_func(x)\n", - "\n", - " batch_size = x.size(0)\n", - " # Get the absolute values of the standard deviations\n", - " stds = torch.abs(model.solve_t(model.defParam)) * model.stdRatio\n", - " Data = torch.tensor(model.data).to(exp.device)\n", - " \n", - " # Check for surrogate\n", - " if surrogate:\n", - " modelOut = exp.surrogate.forward(x)\n", - " else:\n", - " modelOut = model.solve_t(transform.forward(x))\n", - "\n", - " # Eval LL\n", - " ll1 = -0.5 * np.prod(model.data.shape) * np.log(2.0 * np.pi)\n", - " ll2 = (-0.5 * model.data.shape[1] * torch.log(torch.prod(stds))).item()\n", - " ll3 = 0.0\n", - " for i in range(3):\n", - " ll3 += - 0.5 * torch.sum(((modelOut[:, i].unsqueeze(1) - Data[i, :].unsqueeze(0)) / stds[0, i]) ** 2, dim=1)\n", - " negLL = -(ll1 + ll2 + ll3)\n", - " res = -negLL.reshape(x.size(0), 1) + adjust\n", - " np.savetxt(output_dir + '/' + exp.name + '_res', res.detach().numpy(), newline=\"\\n\")\n", - " return res\n", - "\n", - "# Assign log-density model\n", - "exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)\n", - "\n", - "# Run VI\n", - "exp.run()" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 95, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from IPython.display import IFrame\n", - "IFrame(\"./data_plot_phys_25000_0_1.pdf\", width=200, height=155)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorial/data_phys.txt b/tutorial/data_phys.txt index 77b42cd..37bb7e3 100644 --- a/tutorial/data_phys.txt +++ b/tutorial/data_phys.txt @@ -1,3 +1,3 @@ -9.350308179855346680e-01 1.025287508964538574e+00 9.610000252723693848e-01 9.638672471046447754e-01 9.564632177352905273e-01 9.959252476692199707e-01 9.944854378700256348e-01 9.611032605171203613e-01 9.888327121734619141e-01 8.584257364273071289e-01 8.933780789375305176e-01 1.006429553031921387e+00 1.010190248489379883e+00 9.263296723365783691e-01 1.002564191818237305e+00 9.172156453132629395e-01 9.575013518333435059e-01 1.042762875556945801e+00 9.500958323478698730e-01 9.074602127075195312e-01 1.007630467414855957e+00 9.692767262458801270e-01 9.685617089271545410e-01 9.683685898780822754e-01 9.386281967163085938e-01 9.247136116027832031e-01 8.843302726745605469e-01 1.028766155242919922e+00 9.945049285888671875e-01 9.386556148529052734e-01 9.623722434043884277e-01 9.889562726020812988e-01 9.635400176048278809e-01 9.774058461189270020e-01 1.022206425666809082e+00 8.963786959648132324e-01 9.112113714218139648e-01 9.313991665840148926e-01 1.003261446952819824e+00 9.027391076087951660e-01 1.050069570541381836e+00 9.089088439941406250e-01 9.269468188285827637e-01 9.587357044219970703e-01 9.768312573432922363e-01 1.010041117668151855e+00 1.023543238639831543e+00 9.202455878257751465e-01 9.040115475654602051e-01 9.340862035751342773e-01 -3.071713685989379883e+00 2.948431015014648438e+00 3.193027734756469727e+00 3.239506483078002930e+00 3.318561315536499023e+00 3.246345043182373047e+00 3.320601463317871094e+00 3.117549180984497070e+00 3.209284305572509766e+00 2.914004802703857422e+00 3.422102451324462891e+00 3.414747238159179688e+00 2.790281295776367188e+00 3.031339883804321289e+00 3.336751461029052734e+00 3.226713180541992188e+00 3.072079658508300781e+00 3.148079872131347656e+00 3.080028057098388672e+00 3.221025228500366211e+00 2.975959300994873047e+00 3.168765068054199219e+00 3.025683164596557617e+00 3.206740379333496094e+00 3.299416542053222656e+00 3.295787811279296875e+00 3.338684320449829102e+00 3.037313222885131836e+00 3.275540590286254883e+00 3.298203468322753906e+00 3.517847776412963867e+00 3.249591112136840820e+00 3.568951845169067383e+00 3.215488910675048828e+00 3.248165369033813477e+00 3.050978660583496094e+00 3.282874822616577148e+00 3.260029315948486328e+00 3.277524232864379883e+00 3.336892366409301758e+00 2.848696470260620117e+00 3.150529861450195312e+00 3.053777933120727539e+00 3.215996026992797852e+00 3.261392831802368164e+00 3.467787027359008789e+00 3.026314496994018555e+00 3.272170066833496094e+00 3.294956207275390625e+00 3.338265895843505859e+00 -9.132757782936096191e-01 8.980929255485534668e-01 8.970986008644104004e-01 8.918692469596862793e-01 7.469459772109985352e-01 8.396638631820678711e-01 8.572279810905456543e-01 8.780929446220397949e-01 8.807671070098876953e-01 9.769423007965087891e-01 8.689948320388793945e-01 8.640652894973754883e-01 8.288339972496032715e-01 8.588420152664184570e-01 8.327111601829528809e-01 8.844429254531860352e-01 8.536578416824340820e-01 9.066747426986694336e-01 8.843750357627868652e-01 7.996271848678588867e-01 9.082915782928466797e-01 8.925025463104248047e-01 8.581483364105224609e-01 9.194225072860717773e-01 8.732789754867553711e-01 9.291525483131408691e-01 9.005700945854187012e-01 8.950113058090209961e-01 8.488702178001403809e-01 9.293606281280517578e-01 8.890337944030761719e-01 9.192575812339782715e-01 9.007933139801025391e-01 9.191873669624328613e-01 8.591240644454956055e-01 8.466197848320007324e-01 9.172052145004272461e-01 8.813218474388122559e-01 9.023473262786865234e-01 8.632329702377319336e-01 8.443111777305603027e-01 8.208028674125671387e-01 8.853988051414489746e-01 8.694545030593872070e-01 9.078868627548217773e-01 9.324668645858764648e-01 9.852117896080017090e-01 8.648710250854492188e-01 8.921557068824768066e-01 8.736975193023681641e-01 +1.022542227161243433e+00 8.648399523803160793e-01 8.785865330703623854e-01 9.123876878555089442e-01 9.026687139807665350e-01 9.140103159226196095e-01 9.437509232917689062e-01 1.055947742481672558e+00 1.015129950606965981e+00 9.682885010314892238e-01 9.407355398743771913e-01 9.895951843036737694e-01 9.388740164126563315e-01 8.600942391992139058e-01 9.883153399268380657e-01 8.858033335824336829e-01 9.872044318173734956e-01 1.010926342551413892e+00 9.506146028801211179e-01 8.710661965305065424e-01 9.561073908035837565e-01 9.438298629135605244e-01 9.325756580279398467e-01 9.177530263331181715e-01 9.957414912347044567e-01 9.672129717142071703e-01 8.904365146031788525e-01 9.778370093548859332e-01 1.030119394434029667e+00 9.670419059485374502e-01 1.047548151382303949e+00 8.971941327139919542e-01 9.258428348146731102e-01 9.734936388526403972e-01 1.021110018902126715e+00 1.044785598949882433e+00 9.255201122299443472e-01 8.212305172745261173e-01 1.008512106922913931e+00 1.050224547543557208e+00 9.296838973965380060e-01 9.309713174523283064e-01 8.945280823570742612e-01 9.350250795250762970e-01 8.909414115046253579e-01 9.484055888224232067e-01 9.714460248275962329e-01 9.189081943912712491e-01 1.034081529241103414e+00 9.579792412465970575e-01 +3.115354959108145305e+00 3.155056015861934071e+00 3.100304673190827920e+00 3.106075415026240183e+00 3.257002096561620430e+00 3.151432185527572205e+00 3.245522809976848500e+00 3.342064033612559548e+00 3.179656501338931296e+00 2.951313307927196483e+00 3.079383022771200906e+00 3.248299272681918204e+00 3.131159672817982020e+00 3.099106014405357712e+00 3.248473498492943534e+00 3.027027468398562959e+00 2.975278701504127543e+00 3.135655673270797195e+00 3.165492862029798626e+00 3.089689309479763057e+00 3.437991918220936860e+00 3.441357527411134765e+00 3.092025339974505016e+00 3.431696294621870535e+00 3.323244589931568793e+00 3.771379398035271002e+00 3.457190396036500690e+00 3.236226188431612805e+00 3.128945578267268157e+00 2.958624246046679840e+00 3.104689173530587709e+00 3.368150361729516540e+00 2.991678869753290559e+00 3.345958900474336861e+00 3.151230118434096283e+00 3.055918634772253561e+00 3.193804514451996646e+00 2.948388172036800814e+00 3.094456198388309875e+00 3.135799938430555489e+00 3.036715383312711847e+00 3.374327748576424479e+00 3.408641269359727666e+00 3.450912736535319603e+00 3.284394551285571229e+00 3.184996746035629656e+00 3.344201297829540120e+00 3.042564238430386592e+00 3.179508679774594171e+00 3.078774665104720309e+00 +8.718981308692004273e-01 9.394944599339473124e-01 8.109840245125449210e-01 8.674342044635487969e-01 8.377372815998970212e-01 8.398872682088270869e-01 8.297968425718058594e-01 8.240652341754092225e-01 8.844712455489686098e-01 8.566640675197852994e-01 8.887140419830728000e-01 8.328083088980998694e-01 9.145396425922491801e-01 8.873659514994393094e-01 9.163559824550125965e-01 9.277567804574815558e-01 8.745188341138259158e-01 9.518897584989245431e-01 8.609866735049213071e-01 8.489325663161423341e-01 8.350882923865023955e-01 8.832846203453300626e-01 9.525269552868314005e-01 9.261203371449789890e-01 8.671296911742534252e-01 9.300830077803385887e-01 7.839537405389693792e-01 8.329879320898978534e-01 9.042496947415379349e-01 8.380197004541889427e-01 8.695846290274023005e-01 9.338990523645058772e-01 9.273635455453536069e-01 8.862047162995956295e-01 8.387615434590119934e-01 8.775590044510869214e-01 9.041493712263121152e-01 9.031457054823411879e-01 8.729364652371947031e-01 9.084451578608553346e-01 9.399531050305723889e-01 8.859809757901814242e-01 8.896893416994795523e-01 9.245328712709195429e-01 9.587006321808235754e-01 8.704411848555125841e-01 9.345193572841126173e-01 9.452187776124448826e-01 8.547294281132449267e-01 9.590877935737233129e-01 diff --git a/tutorial/imgs/non_ident.png b/tutorial/imgs/non_ident.png new file mode 100644 index 0000000..031e37d Binary files /dev/null and b/tutorial/imgs/non_ident.png differ diff --git a/tutorial/imgs/orig_data_1.png b/tutorial/imgs/orig_data_1.png new file mode 100644 index 0000000..3b2321d Binary files /dev/null and b/tutorial/imgs/orig_data_1.png differ diff --git a/tutorial/imgs/orig_data_2.png b/tutorial/imgs/orig_data_2.png new file mode 100644 index 0000000..40a6621 Binary files /dev/null and b/tutorial/imgs/orig_data_2.png differ diff --git a/tutorial/imgs/orig_data_3.png b/tutorial/imgs/orig_data_3.png new file mode 100644 index 0000000..163818d Binary files /dev/null and b/tutorial/imgs/orig_data_3.png differ diff --git a/tutorial/imgs/orig_log.png b/tutorial/imgs/orig_log.png new file mode 100644 index 0000000..f992659 Binary files /dev/null and b/tutorial/imgs/orig_log.png differ diff --git a/tutorial/imgs/orig_params_1.png b/tutorial/imgs/orig_params_1.png new file mode 100644 index 0000000..16e3df8 Binary files /dev/null and b/tutorial/imgs/orig_params_1.png differ diff --git a/tutorial/imgs/orig_params_2.png b/tutorial/imgs/orig_params_2.png new file mode 100644 index 0000000..a103938 Binary files /dev/null and b/tutorial/imgs/orig_params_2.png differ diff --git a/tutorial/imgs/orig_params_3.png b/tutorial/imgs/orig_params_3.png new file mode 100644 index 0000000..9b89059 Binary files /dev/null and b/tutorial/imgs/orig_params_3.png differ diff --git a/tutorial/imgs/trajectories.png b/tutorial/imgs/trajectories.png new file mode 100644 index 0000000..84e114c Binary files /dev/null and b/tutorial/imgs/trajectories.png differ diff --git a/tutorial/tutorial_linfa.ipynb b/tutorial/tutorial_linfa.ipynb new file mode 100644 index 0000000..9235775 --- /dev/null +++ b/tutorial/tutorial_linfa.ipynb @@ -0,0 +1,1077 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LINFA Tutorial\n", + "This LINFA tutorial will guide through the definition of each of the quantities and functions of LINFA by applying LINFA to a practical problem set.\n", + "
\n", + "\n", + "* **What is LINFA?**\n", + "
\n", + " LINFA is a library for variational inference with normalizing flow and adaptive annealing. LINFA accommodates computationally expensive models and difficult-to-sample posterior distributions with dependent parameters.\n", + "* **Why use LINFA?**\n", + "
\n", + " Designed as a general inference engine, LINFA allows the user to define custom input transformations, computational models, surrogates, and likelihood functions which will be discussed throughout the tutorial." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tutorial outline\n", + "In this tutorial we will:\n", + "1. Define the model (problem set) to apply functions and quantities supported by LINFA.\n", + "2. Check if the model gradients computed by PyTorch match with simple finite difference-based approximations.\n", + "3. Model evaluation set up process and applications:\n", + " * Application 1: Variational inference with the original model.\n", + " * Application 2: Variational inference with neural network surrogate model.\n", + "\n", + "After going through this tutorial, users should be able to define and integrate their model with LINFA, and use the various features provided by LINFA to perform variational inference.\n", + "
\n", + "\n", + "In addition, we emphasize two special features available through LINFA:\n", + "* Adaptively trained surrogate models (NoFAS module).\n", + "* Adaptive annealing schedulers (AdaAnn module).\n", + "\n", + "We encourage the user to take advantage of such modules, especially when using physics-based models with computationally expensive evaluations. It relieves the need of heavy computations such as gradient calculation directly through the model which reduces computational cost of inference, particularly for difficult-to-sample distributions." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Resources\n", + "\n", + "#### Background theory and examples for LINFA\n", + "* Y. Wang, F. Liu and D.E. Schiavazzi, Variational Inference with NoFAS: Normalizing Flow with Adaptive Surrogate for Computationally Expensive Models: https://www.sciencedirect.com/science/article/abs/pii/S0021999122005162\n", + "* E.R. Cobian, J.D. Hauenstein, F. Liu and D.E. Schiavazzi, AdaAnn: Adaptive Annealing Scheduler for Probability Density Approximation:\n", + "https://www.dl.begellhouse.com/journals/52034eb04b657aea,796f39cb1acf1296,6f85fe1149ff41d9.html?sgstd=1\n", + "\n", + "\n", + "#### More about LINFA library: \n", + "* LINFA library [documentation](https://linfa-vi.readthedocs.io/en/latest/index.html).\n", + "* LINFA GitHub [repository](https://github.com/desResLab/LINFA)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Import libraries ##\n", + "import os\n", + "from linfa.run_experiment import experiment\n", + "from linfa.transform import Transformation\n", + "from linfa.nofas import Surrogate\n", + "import torch\n", + "import random\n", + "import numpy as np\n", + "import pandas as pd\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Problem definition\n", + "* Our physics-based model **phys** consists of a simple ballistic model. We would like to compute the quantities: \n", + " * maximum height (m) $x_{1}$,\n", + " * final landing location (m) of the object $x_{2}$,\n", + " * total flight time (s) $x_{2}$,\n", + " \n", + " from the inputs:\n", + " * starting position (m) $z_{1}$,\n", + " * initial velocity (m/s) $z_{2}$,\n", + " * angle (degrees) $z_{3}$.\n", + "
\n", + "\n", + "The model is described by the following equations\n", + "$$\n", + "x_{1} = \\frac{z_{2}^{2}\\,\\sin^{2}(z_{3})}{2\\,g},\\,\\,\n", + "x_{2} = z_{1} + \\frac{z_{2}^{2}\\,\\sin(2\\,z_{3})}{g},\\,\\,\n", + "x_{3} = \\frac{2\\,z_{2}\\,\\sin(z_{3})}{g}.\n", + "$$" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model identifiability\n", + "\n", + "By considering fixed values for the outputs $(x_{1},x_{2},x_{3}) = (\\widetilde{x}_{1},\\widetilde{x}_{2},\\widetilde{x}_{3})$, we can perform some algebraic manipulation. For example if we derive $z_{2}$ from the equation for $x_{3}$ and we plug it back in the equation for $x_{1}$, we get the equation\n", + "$$\n", + "g\\,\\frac{\\widetilde{x}_{3}^{2}}{8} = \\widetilde{x}_{1}^{2}.\n", + "$$\n", + "\n", + "This suggests the maximum height and time of flight are, as expected, related by a deterministic condition and therefore only one of these provide an independent information for the solution of the inverse problem. \n", + "\n", + "Due to this relation, the number of observables is reduced to only two, from three inputs. This results in a non-identifiable inference task. In other words, there is an infinite number of input combinations $(z_{1},z_{2},z_{3})$ corresponding to the outputs $(\\widetilde{x}_{1},\\widetilde{x}_{2},\\widetilde{x}_{3})$. \n", + "\n", + "A graphical explanation for this lack of identifiability can be is shown in the the plot below\n", + "\n", + "
\n", + "Examples of trajectories resulting in the same landing distance and maximum height (or time of flight).\n", + "\n", + "This picture shows how the final target location at $x_{2}$ can be reached by multiple initial positions, velocities and angles. The lack of indetifiability also translates in the existence of a one-dimensional manifold of inputs that correspond to the same outputs. This manifold can be determined from the following expressions for the relations $z_1(z_{3})$ and $z_2(z_{3})$ \n", + "$$\n", + "z_{1} = \\widetilde{x}_{2} - \\frac{g\\cdot \\widetilde{x}_{3}^{2}}{2}\\cdot \\left[\\frac{\\cos(z_{3})}{\\sin(z_{3})}\\right],\\,\\,\n", + "z_{2} = \\frac{g\\cdot \\widetilde{x}_{3}}{2\\,\\sin(z_{3})}.\n", + "$$\n", + "These two curves are also plotted below.\n", + "\n", + "
\n", + "Two-dimensional projections of one-dimensional manifold where all parameters correspond to the same outputs." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Implementation as a Python class\n", + "\n", + "* We now create a new **Phys** model class, having three member functions:\n", + " * `__init__`: A constructor.\n", + " * `genDataFile`: A member function to create synthetic observations.\n", + " * `solve_t`: A function to perform forward model evaluations.\n", + " * *Please refer to the comments below for additional implementation details.* " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#### Implementation of the traditional trajectory motion physics problem ####\n", + "class Phys:\n", + " \n", + " ### Define constructor function for Phys class ###\n", + " def __init__(self):\n", + " ## Define input parameters (True value) \n", + " # input[] = [starting_position, initial_velocity, angle] = [1(m), 5(m/s), 60(degs)]\n", + " self.defParam = torch.Tensor([[1.0, 5.0, 60.0]])\n", + "\n", + " self.gConst = 9.81 # gravitational constant\n", + " self.stdRatio = 0.05 # standard deviation ratio\n", + " self.data = None # data set of model sample\n", + "\n", + " ### Define data file generator function ###\n", + " # dataSize (int): size of sample (data)\n", + " # dataFileName (String): name of the sample data file\n", + " # store (Boolean): True if user wish to store the generated data file; False otherwise.\n", + " def genDataFile(self, dataSize = 50, dataFileName=\"data_phys.txt\", store=True):\n", + " def_out = self.solve_t(self.defParam)[0]\n", + " print(def_out)\n", + " self.data = def_out + self.stdRatio * torch.abs(def_out) * torch.normal(0, 1, size=(dataSize, 3))\n", + " self.data = self.data.t().detach().numpy()\n", + " if store: np.savetxt(dataFileName, self.data)\n", + " return self.data\n", + "\n", + " ### Define data file generator function ###\n", + " # params (Tensor): input parameters storing starting position, initial velocity, and angle in corresponding order.\n", + " def solve_t(self, params):\n", + " z1, z2, z3 = torch.chunk(params, chunks=3, dim=1) # input parameters\n", + " z3 = z3 * (np.pi / 180) # convert unit from degree to radians\n", + " \n", + " ## Output value calculation\n", + " # ouput[] = [maximum_height, final_location, total_time]\n", + " x = torch.cat(( (z2 * z2 * torch.sin(z3) * torch.sin(z3)) / (2.0 * self.gConst), # x1: maxHeight\n", + " z1 + ((z2 * z2 * torch.sin(2.0 * z3)) / self.gConst), # x2: finalLocation \n", + " (2.0 * z2 * torch.sin(z3)) / self.gConst), 1) # x3: totalTime\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Generate phys sample file ##\n", + "\n", + "# Define model\n", + "model = Phys()\n", + "\n", + "# Generate Data\n", + "physData = model.genDataFile()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have our model set up, we go on to our second step, i.e., *Calculating the gradient to confirm model functionality.*" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check for Gradient Calculation\n", + "* Prior to applying NOFAS to our Phys model, we check if the model gradient (Jacobian actually since it has multiple outputs) is correctly computed by PyTorch. \n", + "* Specifically, when the surrogate is not enabled, gradient calculation is completed straight through the model, so we want to ensure that this is correct before running some inference task.\n", + "* Here we compute each gradient using (1) Pytorch and (2) a finite difference (Euler forward differences) approximation, and compare the results provided by these two approaches." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + " #### Computing gradients through PyTorch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#### Implementation of gradient calculation using PyTorch - version 2 #### \n", + "class PytorchGrad2: \n", + " ### Define constructor function for PytorchGrad2 class ###\n", + " def __init__(self, model, transform):\n", + " # Define input parameters and enable gradient calculation\n", + " self.z = torch.Tensor([[1.0, 5.0, 60.0]])\n", + " self.z.requires_grad = True\n", + " \n", + " self.in_vals = transform.forward(self.z)\n", + "\n", + " self.out_val = model.solve_t(self.in_vals)\n", + " self.out1, self.out2, self.out3 = torch.chunk(self.out_val, chunks=3, dim=1)\n", + "\n", + " # Compute gradients using backward function for y\n", + " def back_x1(self): \n", + " self.out1.backward()\n", + " d1 = self.in_vals.grad\n", + " a, b, c = torch.chunk(d1, chunks=3, dim=1)\n", + " return [a, b, c]\n", + " \n", + " def back_x2(self): \n", + " self.out2.backward()\n", + " d2 = self.in_vals.grad\n", + " a, b, c = torch.chunk(d2, chunks=3, dim=1)\n", + " return [a, b, c]\n", + " \n", + " def back_x3(self): \n", + " self.out3.backward()\n", + " d3 = self.in_vals.grad\n", + " a, b, c = torch.chunk(d3, chunks=3, dim=1)\n", + " return [a, b, c]\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define Phys model\n", + "model = Phys()\n", + "# Set transformation information and define transforamtion\n", + "trsf_info = [['identity',0.0,0.0,0.0,0.0],\n", + " ['identity',0,0.0,0.0,0.0],\n", + " ['identity',0,0.0,0.0,0.0]]\n", + " \n", + "transform = Transformation(trsf_info)\n", + "\n", + "# List to store dx/dz values\n", + "dx_dz_pytorch2 = []\n", + "\n", + "# Define PytorchGrad object and calculate gradient\n", + "pyGrad2 = PytorchGrad2(model, transform)\n", + "dx_dz_pytorch2.append(pyGrad2.back_x1())\n", + "\n", + "pyGrad2 = PytorchGrad2(model, transform)\n", + "dx_dz_pytorch2.append(pyGrad2.back_x2())\n", + "\n", + "pyGrad2 = PytorchGrad2(model, transform)\n", + "dx_dz_pytorch2.append(pyGrad2.back_x3())\n", + "\n", + "# print(dx_dz_pytorch2) # check if output matches expectations\n", + "\n", + "# convert to pandas DataFrame for readability\n", + "jacob_mat_2 = pd.DataFrame(dx_dz_pytorch2.numpy(), columns=['dz1', 'dz2', 'dz3'])\n", + "jacob_mat_2.index = ['dx1', 'dx2', 'dx3']\n", + "jacob_mat_2" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Approximating gradients with finite differences" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### Function that manually calculates a derivative ###\n", + "def getGrad(f_eps, f, eps):\n", + " return (f_eps - f) / (eps)\n", + "\n", + "### Function that returns a list of gradients ###\n", + "def gradList(f_eps1, f_eps2, f_eps3, f, eps): \n", + " return [getGrad(f_eps1, f, eps), getGrad(f_eps2, f, eps), getGrad(f_eps3, f, eps)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List to store dx/dz values\n", + "dx_dz = []\n", + "dx1_dz = []\n", + "dx2_dz = []\n", + "dx3_dz = []\n", + "\n", + "# Set up parameters\n", + "eps = 1.0\n", + "z = torch.Tensor([[1.0, 5.0, 60.0]])\n", + "z_eps1 = torch.Tensor([[1.0 + eps, 5.0, 60.0]])\n", + "z_eps2 = torch.Tensor([[1.0, 5.0 + eps, 60.0]])\n", + "z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + eps]])\n", + "\n", + "x1_eps1 = model.solve_t(z_eps1)[0,0]\n", + "x1_eps2 = model.solve_t(z_eps2)[0,0]\n", + "x1_eps3 = model.solve_t(z_eps3)[0,0]\n", + "x1_eps = model.solve_t(z)[0,0]\n", + "\n", + "dx1_dz = gradList(x1_eps1, x1_eps2, x1_eps3, x1_eps, eps)\n", + "dx_dz.append(dx1_dz)\n", + "\n", + "x2_eps1 = model.solve_t(z_eps1)[0,1]\n", + "x2_eps2 = model.solve_t(z_eps2)[0,1]\n", + "x2_eps3 = model.solve_t(z_eps3)[0,1]\n", + "x2_eps = model.solve_t(z)[0,1]\n", + "\n", + "dx2_dz = gradList(x2_eps1, x2_eps2, x2_eps3, x2_eps, eps)\n", + "dx_dz.append(dx2_dz)\n", + "\n", + "x3_eps1 = model.solve_t(z_eps1)[0,2]\n", + "x3_eps2 = model.solve_t(z_eps2)[0,2]\n", + "x3_eps3 = model.solve_t(z_eps3)[0,2]\n", + "x3_eps = model.solve_t(z)[0,2]\n", + "\n", + "dx3_dz = gradList(x3_eps1, x3_eps2, x3_eps3, x3_eps, eps)\n", + "dx_dz.append(dx3_dz)\n", + "\n", + "# print(dx_dz) # check if values match expected outputs\n", + "\n", + "# convert to pandas DataFrame for readability\n", + "jacob_mat_3 = pd.DataFrame(dx_dz, columns=['dz1', 'dz2', 'dz3'])\n", + "jacob_mat_3.index = ['dx1', 'dx2', 'dx3']\n", + "jacob_mat_3" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### We verify the convergence for the finite difference approximation to the PyTorch gradient for dx2_dz3\n", + "- Note: adjust values to check convergence for other gradients of interest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Focus: dx2_dz3\n", + "\n", + "initial_eps = 15 # Initial change of value (eps)\n", + "k = 150 # Number of iterations\n", + "dx2_dz3_list = [] # List to store results\n", + "pytorch_grad2 = -0.0445 # Pytorch gradient value\n", + "\n", + "# Calculate for dx2_dz3 as eps decreases\n", + "for t in range(1, k):\n", + " update_eps = initial_eps*(1/t) # updated eps value\n", + " z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + update_eps]]) # update z_eps3\n", + " x2_eps3 = model.solve_t(z_eps3)[0,1] # update x2_eps3\n", + " dx2_dz3_list.append(getGrad(x2_eps3, x2_eps, update_eps)) # store result to dx2_dz3_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Plot result to see convergence\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(range(1,k), dx2_dz3_list, c = \"red\", linestyle = \"solid\", label = \"Model Gradient\")\n", + "\n", + "plt.axhline(y = pytorch_grad2, color = 'blue', linestyle = '-', label = \"Pytorch Gradient-2\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.title(\"Gradient Plot for dx2_dz3\")\n", + "plt.ylabel(\"Gradient\")\n", + "plt.xlabel(\"k-Iterations\")\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we confirmed that our model successfully computes the gradients, we go on to our third step: *Model Evaluation Set Up and Applications*" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Variational inference with full model\n", + "\n", + "#### Definition of hyperparameters\n", + "The first step is to define all options and hyperparameters for the inference task. Additional detail for each hyperparameter can be found in the [documentation](https://linfa-vi.readthedocs.io/en/latest/content/linfa_options.html) or in the definition of the [experiment](https://github.com/desResLab/LINFA/blob/master/linfa/run_experiment.py) class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Experiment Setting\n", + "exp = experiment()\n", + "exp.flow_type = 'maf' # str: Type of flow\n", + "exp.n_blocks = 5 # int: Number of layers \n", + "exp.hidden_size = 100 # int: Hidden layer size for MADE in each layer \n", + "exp.n_hidden = 1 # int: Number of hidden layers in each MADE \n", + "exp.activation_fn = 'relu' # str: Actication function used \n", + "exp.input_order = 'sequential' # str: Input order for create_mask \n", + "exp.batch_norm_order = True # boolean: Order to decide if batch_norm is used \n", + "exp.save_interval = 5000 # int: How often to sample from normalizing flow\n", + "\n", + "exp.input_size = 3 # int: Dimensionality of input \n", + "exp.batch_size = 250 # int: Number of samples generated \n", + "exp.true_data_num = 2 # double: number of true model evaluated \n", + "exp.n_iter = 25001 # int: Number of iterations \n", + "exp.lr = 0.01 # float: Learning rate \n", + "exp.lr_decay = 0.9999 # float: Learning rate decay \n", + "exp.log_interval = 100 # int: How often to show loss stat \n", + "\n", + "exp.run_nofas = False # boolean: to run experiment with nofas\n", + "exp.annealing = False # boolean: to run experiment with annealing\n", + "exp.calibrate_interval = 1000 # int: How often to update surrogate model \n", + "exp.budget = 260 # int: Total number of true model evaluation\n", + "\n", + "exp.surr_pre_it = 20000 # int: Number of pre-training iterations for surrogate model\n", + "exp.surr_upd_it = 6000 # int: Number of iterations for the surrogate model update\n", + "exp.surr_folder = \"./\"\n", + "exp.use_new_surr = True # boolean: to run experiment with nofas\n", + "\n", + "exp.results_file = 'results.txt' # str: result text file name\n", + "exp.log_file = 'log.txt' # str: log text file name\n", + "exp.samples_file = 'samples.txt' # str: sample text file name\n", + "exp.seed = random.randint(0, 10 ** 9) # int: Random seed used\n", + "exp.n_sample = 5000 # int: Total number of iterations\n", + "exp.no_cuda = True # boolean: to run experiment with NO cuda\n", + "\n", + "exp.optimizer = 'RMSprop' # str: Type of optimizer\n", + "exp.lr_scheduler = 'ExponentialLR' # str: Type of scheduler\n", + "\n", + "exp.device = torch.device('cuda:0' if torch.cuda.is_available() and not exp.no_cuda else 'cpu')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Define the transformation \n", + "Now we define the trasformation of parameters and initialize the " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define transformation based on normalization rate\n", + "trsf_info = [['identity',0.0,0.0,0.0,0.0],\n", + " ['identity',0.0,0.0,0.0,0.0],\n", + " ['linear',-3,3,30.0,80.0]]\n", + "trsf = Transformation(trsf_info) \n", + "exp.transform = trsf" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model and surrogate definition\n", + "We create an instance of the **Phys** model and assign `None` to the surrogate." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define model\n", + "model = Phys()\n", + "exp.model = model\n", + "\n", + "# Get data\n", + "model.data = np.loadtxt('./data_phys.txt')\n", + "\n", + "# Run experiment without surrogate\n", + "exp.surrogate = None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Log-likelihood definiton\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Define log density\n", + "# x: original, untransformed inputs\n", + "# model: our model\n", + "# transform: our transformation \n", + "def log_density(x, model, surrogate, transform):\n", + " # x contains the original, untransformed inputs\n", + " np.savetxt(exp.output_dir + '/' + exp.name + '_x', x.detach().numpy(), newline=\"\\n\")\n", + " # Compute transformation log Jacobian\n", + " adjust = transform.compute_log_jacob_func(x)\n", + "\n", + " batch_size = x.size(0)\n", + " # Get the absolute values of the standard deviations\n", + " stds = torch.abs(model.solve_t(model.defParam)) * model.stdRatio\n", + " Data = torch.tensor(model.data).to(exp.device)\n", + " \n", + " # Check for surrogate\n", + " if surrogate:\n", + " modelOut = exp.surrogate.forward(x)\n", + " else:\n", + " modelOut = model.solve_t(transform.forward(x))\n", + "\n", + " # Eval LL\n", + " ll1 = -0.5 * np.prod(model.data.shape) * np.log(2.0 * np.pi)\n", + " ll2 = (-0.5 * model.data.shape[1] * torch.log(torch.prod(stds))).item()\n", + " ll3 = 0.0\n", + " for i in range(3):\n", + " ll3 += - 0.5 * torch.sum(((modelOut[:, i].unsqueeze(1) - Data[i, :].unsqueeze(0)) / stds[0, i]) ** 2, dim=1)\n", + " negLL = -(ll1 + ll2 + ll3)\n", + " res = -negLL.reshape(x.size(0), 1) + adjust\n", + " np.savetxt(exp.output_dir + '/' + exp.name + '_res', res.detach().numpy(), newline=\"\\n\")\n", + " return res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Launch inference task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Run \n", + "print('')\n", + "print('--- Temporary TEST: Physics Example - without NOFAS')\n", + "print('')\n", + "\n", + "print('--- Running on device: '+ str(exp.device))\n", + "print('')\n", + "\n", + "# Experiment Setting\n", + "exp.name = \"phys_nofasFree\" # str: Name of experiment\n", + "exp.output_dir = './' + exp.name # str: output directory location\n", + "\n", + "# Assign logdensity\n", + "exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)\n", + "\n", + "# Run VI\n", + "exp.run()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the model evaluation has been successfully completed by checking at the newly created **phys_nofasFree** folder in our current directory.\n", + "Note also that LINFA supports a post processing script to plot all results which includes plots of log loss, parameter estimation, and estimated output data.\n", + "\n", + "The according code line is: `python -m linfa.plot_res -n phys_nofasFree -i 25000 -f phys_nofasFree`\n", + " \n", + "> However, even with our simple model, computing the gradients and confirming model functionality everytime before applying the model evaluation is time consuming and when it comes with evaluting even more complex models, the computational costs will be exponential and will likely result in intractable inference.
\n", + "In such cases, LINFA enables the construction of the adaptively trained surrogate model which resolves such concerns regarding the computational cost of inference. By utilizing the surrogate model, gradient computation is executed by the surrogate model which eliminates the need to manually check for gradient calcultation.
\n", + "In addition, LINFA provides an adaptive annealing scheduler which allows easier sampling from complicated densities.
\n", + "\n", + "> Accordingly, we will specifically observe how the adaptively trained surrogate model relieves compuatational cost while ensuring inference in our last step:
\n", + "*Applying our model including the Surrogate model.*" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Results\n", + "The results below are obtained from variational inference with the full model. The posterior distribution is concentrated around the one-dimensional manifold where the input parameters are not identifiable. \n", + "\n", + "
\n", + "Loss converge profiles for variational inference with the full model.
\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Posterior distribution of the input parameters.
\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Comparison between posterior predictive distribution and available observations.
\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import IFrame\n", + "IFrame(\"./samples/simple3.pdf\", width=600, height=300)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Phys Model with the Adaptively Trained Surrogate Model\n", + "Note: this block of code takes a while." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Experiment Setting\n", + "exp.name = \"phys\" # str: Name of experiment\n", + "exp.output_dir = './' + exp.name # str: output directory location\n", + "\n", + "# Define model\n", + "model = Phys()\n", + "exp.model = model\n", + "\n", + "# Get data\n", + "model.data = np.loadtxt('./data_phys.txt')\n", + "\n", + "# Define surrogate\n", + "exp.surrogate = Surrogate(exp.name, lambda x: model.solve_t(trsf.forward(x)), exp.input_size, 3, \n", + " model_folder=exp.surr_folder, limits=torch.Tensor([[0, 2], [0, 10], [-3, 3]]), \n", + " memory_len=20, device=exp.device)\n", + "surr_filename = exp.surr_folder + exp.name\n", + "if exp.use_new_surr or (not os.path.isfile(surr_filename + \".sur\")) or (not os.path.isfile(surr_filename + \".npz\")):\n", + " print(\"Warning: Surrogate model files: {0}.npz and {0}.npz could not be found. \".format(surr_filename))\n", + " # 4 samples for each dimension: pre-grid size = 16\n", + "# exp.surrogate.gen_grid(gridnum=4)\n", + " exp.surrogate.gen_grid(gridnum=3)\n", + " exp.surrogate.pre_train(exp.surr_pre_it, 0.03, 0.9999, 500, store=True)\n", + "# Load the surrogate\n", + "exp.surrogate.surrogate_load()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Run \n", + "print('')\n", + "print('--- Temporary TEST: Physics Example - with NOFAS')\n", + "print('')\n", + "\n", + "print('--- Running on device: '+ str(exp.device))\n", + "print('')\n", + "\n", + "# Assign logdensity\n", + "exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)\n", + "\n", + "# Run VI\n", + "exp.run()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> Notice that the model evaluation has been successfully completed by checking at the newly created **phys** folder in our current directory.
\n", + "> Note that LINFA supports a post processing script to plot all results which includes plots of log loss, parameter estimation, and estimated output data.
\n", + "The according code line is: `python -m linfa.plot_res -n phys -i 25000 -f phys`\n", + "
\n", + "\n", + " \n", + "> Now, to compare the functionalities of the surrogate model, we observe the plots generated by the LINFA library.\n", + "
\n", + "Note that the generated plots below are converted to png format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import libraries\n", + "import cv2\n", + "from matplotlib import pyplot as plt\n", + "from pdf2image import convert_from_path\n", + "from IPython.display import IFrame\n", + "\n", + "## Phys. without Surrogate\n", + "# 1-a) output evaluation plots - without surrogate\n", + "output1_free = convert_from_path('./phys_nofasFree/data_plot_phys_nofasFree_25000_0_1.pdf')\n", + "for page in output1_free:\n", + " page.save('data_plot_phys_nofasFree_25000_0_1.png', 'PNG')\n", + " \n", + "output2_free = convert_from_path('./phys_nofasFree/data_plot_phys_nofasFree_25000_0_2.pdf')\n", + "for page in output2_free:\n", + " page.save('data_plot_phys_nofasFree_25000_0_2.png', 'PNG')\n", + " \n", + "output3_free = convert_from_path('./phys_nofasFree/data_plot_phys_nofasFree_25000_1_2.pdf')\n", + "for page in output3_free:\n", + " page.save('data_plot_phys_nofasFree_25000_1_2.png', 'PNG')\n", + " \n", + "# 1-b) output evaluation plots - with surrogate\n", + "output1 = convert_from_path('./phys/data_plot_phys_25000_0_1.pdf')\n", + "for page in output1:\n", + " page.save('data_plot_phys_25000_0_1.png', 'PNG')\n", + " \n", + "output2 = convert_from_path('./phys/data_plot_phys_25000_0_2.pdf')\n", + "for page in output2:\n", + " page.save('data_plot_phys_25000_0_2.png', 'PNG')\n", + " \n", + "output3 = convert_from_path('./phys/data_plot_phys_25000_1_2.pdf')\n", + "for page in output3:\n", + " page.save('data_plot_phys_25000_1_2.png', 'PNG')\n", + " \n", + "# 2-a) log-loss plots - without surrogate\n", + "log_free = convert_from_path('./phys_nofasFree/log_plot.pdf')\n", + "for page in log_free:\n", + " page.save('log_plot_nofasFree.png', 'PNG')\n", + "\n", + "# 2-b) log-loss plots - with surrogate\n", + "log_free = convert_from_path('./phys/log_plot.pdf')\n", + "for page in log_free:\n", + " page.save('log_plot_phys.png', 'PNG')\n", + " \n", + "# 3-a) parameter evaluation plots - without surrogate\n", + "param1_free = convert_from_path('./phys_nofasFree/params_plot_phys_nofasFree_25000_0_1.pdf')\n", + "for page in param1_free:\n", + " page.save('params_plot_phys_nofasFree_25000_0_1.png', 'PNG') \n", + " \n", + "param2_free = convert_from_path('./phys_nofasFree/params_plot_phys_nofasFree_25000_0_2.pdf')\n", + "for page in param2_free:\n", + " page.save('params_plot_phys_nofasFree_25000_0_2.png', 'PNG') \n", + " \n", + "param3_free = convert_from_path('./phys_nofasFree/params_plot_phys_nofasFree_25000_1_2.pdf')\n", + "for page in param3_free:\n", + " page.save('params_plot_phys_nofasFree_25000_1_2.png', 'PNG') \n", + " \n", + "# 3-b) parameter evaluation plots - with surrogate\n", + "param1 = convert_from_path('./phys/params_plot_phys_25000_0_1.pdf')\n", + "for page in param1:\n", + " page.save('params_plot_phys_25000_0_1.png', 'PNG') \n", + " \n", + "param2 = convert_from_path('./phys/params_plot_phys_25000_0_2.pdf')\n", + "for page in param2:\n", + " page.save('params_plot_phys_25000_0_2.png', 'PNG') \n", + " \n", + "param3 = convert_from_path('./phys/params_plot_phys_25000_1_2.pdf')\n", + "for page in param3:\n", + " page.save('params_plot_phys_25000_1_2.png', 'PNG') " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## We first take a look to check if the images are well converted to png\n", + "\n", + "# import package\n", + "from IPython.display import Image, display\n", + "\n", + "# display png image\n", + "img1 = Image(filename='output1_free.png')\n", + "img2 = Image(filename='output2_free.png')\n", + "img3 = Image(filename='output3_free.png')\n", + "# display(img1, img2, img3)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> Now we compare the generated plots of log-loss, output data and parameter estimation." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1) *log-loss* plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## log-loss plots\n", + "\n", + "# create figure\n", + "fig = plt.figure(figsize=(10, 7))\n", + "\n", + "# setting row and column variables\n", + "rows = 1\n", + "columns = 2\n", + "\n", + "# reading images\n", + "logLoss_phys = cv2.imread('log_plot_phys.png')\n", + "logLoss_physFree = cv2.imread('log_plot_nofasFree.png')\n", + "\n", + "# Adds a subplot at the 1st position\n", + "fig.add_subplot(rows, columns, 1)\n", + "\n", + "# Add subplots corresponding to Phys with Surrogate\n", + "plt.imshow(logLoss_phys)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. with Surrogate\")\n", + "\n", + "# Add subplots corresponding to Phys without Surrogate\n", + "fig.add_subplot(rows, columns, 2)\n", + "# showing image\n", + "plt.imshow(logLoss_physFree)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. without Surrogate\")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> While both of the log loss plots show convergence to a lower log loss value, notice that the log loss plot of the Phys. model that utilized the Surrogate model had a lower log loss value than the case when the Surrogate was not used.\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2) *ouput estimation* plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## output plots\n", + "\n", + "# create figure\n", + "fig = plt.figure(figsize=(10, 7))\n", + "\n", + "# setting row and column variables\n", + "rows = 3\n", + "columns = 2\n", + "\n", + "# reading images\n", + "phys1 = cv2.imread('data_plot_phys_25000_0_1.png')\n", + "phys2 = cv2.imread('data_plot_phys_25000_0_2.png')\n", + "phys3 = cv2.imread('data_plot_phys_25000_1_2.png')\n", + "# print('Image Width is',Image1.shape[1]) #327\n", + "# print('Image Height is',Image1.shape[0]) #259\n", + "# Image1 = cv2.resize(Image1, (400,300))\n", + "\n", + "physFree1 = cv2.imread('data_plot_phys_nofasFree_25000_0_1.png')\n", + "physFree2 = cv2.imread('data_plot_phys_nofasFree_25000_0_2.png')\n", + "physFree3 = cv2.imread('data_plot_phys_nofasFree_25000_1_2.png')\n", + "\n", + "# Add subplots corresponding to Phys with Surrogates\n", + "fig.add_subplot(rows, columns, 1)\n", + "plt.imshow(phys1)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. with Surrogate\")\n", + "fig.add_subplot(rows, columns, 3)\n", + "plt.imshow(phys2)\n", + "plt.axis('off')\n", + "fig.add_subplot(rows, columns, 5)\n", + "plt.imshow(phys3)\n", + "plt.axis('off')\n", + "\n", + "# Add subplots corresponding to Phys without Surrogate\n", + "fig.add_subplot(rows, columns, 2)\n", + "plt.imshow(physFree1)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. with Surrogate\")\n", + "fig.add_subplot(rows, columns, 4)\n", + "plt.imshow(physFree2)\n", + "plt.axis('off')\n", + "fig.add_subplot(rows, columns, 6)\n", + "plt.imshow(physFree3)\n", + "plt.axis('off')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> While both of the output plots have meaningful results where the samples (blue dots) are distributed within the estimated region (red cluster), notice that the samples are much more clustered within the estimated area from the Phys. model that utilized the Surrogate model than the case where the Surrogate model was not used.\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3) *parameter estimation* plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## parameter estimation plots\n", + "\n", + "# create figure\n", + "fig = plt.figure(figsize=(10, 7))\n", + "\n", + "# setting row and column variables\n", + "rows = 3\n", + "columns = 2\n", + "\n", + "# reading images\n", + "phys1 = cv2.imread('params_plot_phys_25000_0_1.png')\n", + "phys2 = cv2.imread('params_plot_phys_25000_0_2.png')\n", + "phys3 = cv2.imread('params_plot_phys_25000_1_2.png')\n", + "\n", + "physFree1 = cv2.imread('params_plot_phys_nofasFree_25000_0_1.png')\n", + "physFree2 = cv2.imread('params_plot_phys_nofasFree_25000_0_2.png')\n", + "physFree3 = cv2.imread('params_plot_phys_nofasFree_25000_1_2.png')\n", + "\n", + "# Add subplots corresponding to Phys with Surrogates\n", + "fig.add_subplot(rows, columns, 1)\n", + "plt.imshow(phys1)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. with Surrogate\")\n", + "fig.add_subplot(rows, columns, 3)\n", + "plt.imshow(phys2)\n", + "plt.axis('off')\n", + "fig.add_subplot(rows, columns, 5)\n", + "plt.imshow(phys3)\n", + "plt.axis('off')\n", + "\n", + "# Add subplots corresponding to Phys without Surrogate\n", + "fig.add_subplot(rows, columns, 2)\n", + "plt.imshow(physFree1)\n", + "plt.axis('off')\n", + "plt.title(\"Phys. without Surrogate\")\n", + "fig.add_subplot(rows, columns, 4)\n", + "plt.imshow(physFree2)\n", + "plt.axis('off')\n", + "fig.add_subplot(rows, columns, 6)\n", + "plt.imshow(physFree3)\n", + "plt.axis('off')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> Recall our input values (parameters) were [1,5,60]. Notice that both cases successfully returns parameter estimations that are fairly accurate. Specifically, we observe that when the Surrogate model is used, the accuracy of the parameter estimation increases.\n", + "
\n", + "\n", + "> Based on the generated plots, we've observed the physical benefits of the Surrogate model and the functionality of the Phys. model estimation utilizing LINFA. \n", + "> Moreover, LINFA supports various model types. For more information, please refer to our paper **[need to put our paper url!!!]** *Appendix B. Detailed numerical benchmarks* where examples utilizing the LINFA library is introduced." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}