Skip to content

Commit

Permalink
Removed noise as parameter, now regularization is automatic
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioZeni committed May 17, 2021
1 parent 41b05c7 commit 1315ef9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 53 deletions.
58 changes: 15 additions & 43 deletions examples/Linear Potential.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
},
{
"cell_type": "code",
"execution_count": 163,
"execution_count": 321,
"metadata": {},
"outputs": [],
"source": [
"ntr = 50\n",
"nval = 50\n",
"ntr = 100\n",
"nval = 100\n",
"\n",
"train_structures = ut.load_structures(\"data/Si/train_trajectory.json\")\n",
"val_structures = ut.load_structures(\"data/Si/validation_trajectory.json\")\n",
Expand Down Expand Up @@ -79,15 +79,15 @@
},
{
"cell_type": "code",
"execution_count": 164,
"execution_count": 329,
"metadata": {},
"outputs": [],
"source": [
"ns = 4\n",
"ls = 4\n",
"r_cut = 5.2\n",
"\n",
"pot = lp.LinearPotential('3', ns, ls, r_cut, species, True, basis = 'bessel')"
"pot = lp.LinearPotential('3', ns, ls, r_cut, species, True, basis = 'chebyshev')"
]
},
{
Expand All @@ -99,11 +99,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 330,
"metadata": {},
"outputs": [],
"source": [
"pot.fit(X, ncores = 1, compute_forces = True)"
"pot.fit(X, ncores = 1, compute_forces=True)"
]
},
{
Expand All @@ -115,45 +115,17 @@
},
{
"cell_type": "code",
"execution_count": 176,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-05-17 11:14:18,616\tINFO services.py:1090 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE ENERGY: 57.79 meV/atom\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbCUlEQVR4nO3de3RV5bnv8e9jCLjcbY0IVgkoVC1eSo+hKUdF6rFi2bVYgbZWu7e3aum5cIbsoYjYcUbt6bGgbMRdxx4d3qjSUi9bY7Yge8QL0qpcAxGDcqhIQQyIiTVHLQFD8pw/shKyLklWsuZac11+nzEyRuZlrfkyS388vvN932nujoiIFL6jwm6AiIhkhwJfRKRIKPBFRIqEAl9EpEgo8EVEisSgsBvQm2HDhvno0aPDboaISN7YtGlTk7sPT3YspwN/9OjR1NbWht0MEZG8YWa7ezqmLh0RkSKhwBcRKRIKfBGRIqHAFxEpEgp8EZEikdOjdEREikl1XQMLa7azt7mFEWUR5kwZy7SK8sC+X4EvIpIDqusamFdVT0trGwANzS3Mq6oHCCz01aUjIpIDFtZs7wr7Ti2tbSys2R7YNRT4IiI5YG9zS7/2D4QCX0QkB4woiwDwSd1KWnZuStgfBPXhi4jkgGvHHcNPp17ctX3K3BVESkuYM2VsYNdQhS8iErJrrrmGn049r2t71P9cRnlZhPkzxmmUjohIIairq2P8+PFd2/fffz8zZ87M2PUU+CIiWdbe3s4FF1zA2rVrASgrK2Pv3r1EIsH11yejLh0RkSyqqamhpKSkK+yXL1/ORx99lPGwB1X4IiJZcfDgQUaNGkVTUxMAlZWVrFu3jpKSkqy1IZAK38xuNjM3s2E9HD/ZzJ43s21m9paZjQ7iuiIi+WDJkiVEIpGusN+wYQMbN27MathDABW+mY0CvgW828tpS4E73f0FM/sc0J7udUVEct1HH33E0KFDu7avvPJK/vCHP2BmobQniAp/MXAr4MkOmtlZwCB3fwHA3T919wMBXFdEJGXVdQ1MXLCKMbc9x8QFq6iua8jo9X7xi1/EhP0777zDY489FlrYQ5oVvpldDjS4+5Ze/hBfBprNrAoYA7wI3ObubT19QEQkSNlYmKzT7t27GT16dNf27bffzp133hnoNQaqz8A3sxeBE5Mc+hlwOx3dOX1dYxJQQUe3zxPAdcDDPVxvJjAT4OSTT+6reSIifeptYbIgA//aa69l6dKlXduNjY0MG5b00WYo+uzScffJ7v6V+B9gJx0V+xYz2wWMBDabWfw/Du8Br7v7Tnc/DFQD4+mBuz/g7pXuXjl8+PAB/rFERI7I9MJk69atw8y6wv7+++/H3XMq7CGNLh13rwdO6NyOhn6luzfFnboRKDOz4e7eCHwTqB3odUVE+mtEWYSGJOGe7sJkhw8fprS0tGu7ZMjfsezl1/nheael9b2ZkpGJV2ZWaWYPAUT76m8BXjKzesCABzNxXRGRZOZMGUukNHYI5EAXJut8+HvshBkxYf+F865g5OwnuGPljow/EB6owCZeufvobr/XAjd2234B+GpQ1xIR6Y/Ofvp0Xx9YXdfArX9Yx9v//P2Y/SffUo2VdMRpJp4NBEUzbUWkKEyrKE87hK/9/nf4eOeWru2hl/w3Pj/+OwnnBfnSkiAp8EVE+rBz505OPfXUmH0n37q8xzH1Qb60JEgKfBGRXgwaNIi2tiNDOk+44pdExlT0eH7QLy0JklbLFBFJ4pVXXsHMYsL+mc3vMfTLlTHnlZYYZZFSDDLy0pIgqcIXEYkT31Wzbds2zjjjjK7tdB/+hkWBLyIS9cgjj3D99dd3bVdUVLB58+aYc4J4+BsWBb6IFL22tjYGDYqNw6amJo4//viQWpQZ6sMXkaJ2yy23xIT9zJkzcfeCC3tQhS8iRerjjz/m2GOPjdl36NAhBg8eHFKLMk8VvogUnYsuuigm7O+77z7cvaDDHlThi0gRWbVqFRdffHHMvvb29lBfSpJNqvBFpCiYWUzYP//887h70YQ9KPBFpMAtW7YsIdTdnUsuuSSkFoVHXToiUrDig/6VV17hggsuCKk14VOFLyIFZ/bs2Umr+mIOe1CFLyIFpLW1NWGkzZ49exg5cmRILcotCnwRKQjxFf1JJ53E3r17Q2pNblKXjojktV27diWEfUtLi8I+CQW+iOQtM2PMmDFd22eeeSbuztFHHx1iq3KXAl9E8s7KlSsTqvr29nbeeuutkFqUH9SHLyJ5JT7ob7jhBh566KGUPltd15C3a9kHQRW+iOSFefPmJR1q2Z+wn1dVT0NzCw40NLcwr6qe6rqGDLQ2NynwRSTnmRkLFizo2l62bBnu3q/vWFiznZbWtph9La1tLKzZHkgb84G6dEQkZyVb56a/Qd9pb3NLv/YXIlX4IpJzDh48mBD2W7du7TXsq+samLhgFWNue46JC1YldNWMKIsk/VxP+wuRAl9EcoqZEYnEhrC7c/bZZ/f4mVT65+dMGUuktCTmc5HSEuZMGRto+3OZAl9EcsLbb7+dUNU3Nzen1IWTSv/8tIpy5s8YR3lZBAPKyyLMnzGuqEbpqA9fREKXbl99qv3z0yrKiyrg46nCF5HQPPbYY0knUPX3waz651OjCl9EQhEf9N/4xjf44x//OKDvmjNlLPOq6mO6deL754t90hWowheRLLv66quTTqAaaNhD3/3zmnTVQRW+iGRNfNAvXryY2bNnB/LdvfXP9/ZQt5iqfAW+iGRckBOoBkKTrjqoS0dEMibZBKqNGzdmNexBD3U7KfBFJCN6mkBVWVmZ9bZo0lUHBb6IBCrZBKqPPvoo61V9d5p01UF9+CISmKD76oMcSlnsk65AFb6IBODxxx8PZAJVdxpKGTwFvoikxcy46qqrurYnTZqEuyet9vtD69cHL7DAN7ObzczNbFgPx+82szfNbJuZ/drS/dsgIqG65pprkk6g+tOf/hTI92soZfACCXwzGwV8C3i3h+PnAxOBrwJfAb4OXBjEtUUk+8yM3/3ud13b99xzT+APZTWUMnhBPbRdDNwK/HsPxx04GhgMGFAK7A/o2iKSJdmcQJXK+jjSP2lX+GZ2OdDg7lt6Osfd1wIvA/uiPzXuvq2H75tpZrVmVtvY2Jhu80QkAMkmUG3YsCGjQy01lDJ4KVX4ZvYicGKSQz8DbqejO6e3z58GnAmMjO56wcwmufsr8ee6+wPAAwCVlZXhDdwVESDcZRE0lDJYKVX47j7Z3b8S/wPsBMYAW8xsFx2BvtnM4v9xmA6sc/dP3f1T4D+A84L7Y4hI0Hbs2JFzE6gkPWn14bt7PXBC53Y09CvdvSnu1HeBn5jZfDr68C8E7k3n2iKSOWEvdiaZkbFx+GZWaWYPRTefAt4B6oEtwBZ3X56pa4vIwDzxxBOBT6CS3BHo0gruPrrb77XAjdHf24CfBnktEQlWfNCff/75vPbaayG1RjJBM21Fitx1112XdAKVwr7wKPBFipiZ8eijj3ZtL1q0SN03BUyrZYoUIT2ULU6q8EWKyKFDhxLCfv369Qr7IqEKX6RIqKoXVfgiBU4TqKSTKnyRAqaqXrpThS9SgJ588klNoJIEqvBFCkx80J933nmsWbMmpNZILlGFL1Igrr/++qQTqBT20kmBL1IAzIxHHnmka3vhwoXqvpEE6tIRyWMlJSW0t7fH7FPQS09U4Yvkoc4JVN3Dft26dQp76ZUqfJE8o6GWMlCq8EXyxDvvvJMQ9n/9618V9pIyVfgieUBVvQRBFb5IDtMEKgmSKnyRHBUf9Oeeey5r164NqTVSCFThi+SYH//4x0knUCnsJV0KfJEcYmb89re/7dq+++671X0jgVGXjkgOGDRoEG1tbTH7FPQSNFX4IiHqnEDVPezXrl2rsJeMUIUvEhINtZRsU4UvkmWaQCVhUYUvkkWq6oNVXdfAwprt7G1uYURZhDlTxjKtojzsZuUsVfgiWfDUU09pAlXAqusamFdVT0NzCw40NLcwr6qe6rqGsJuWsxT4IhlmZvzgBz/o2p4wYQLunrTal9QtrNlOS2vsyKaW1jYW1mwPqUW5T4EvkiE33HBD0glU69evD6lFhWVvc0u/9osCXyQjzIwlS5Z0bd91113qvgnYiLJIv/aLHtqKBKq0tJTDhw/H7FPQZ8acKWOZV1Uf060TKS1hzpSxIbYqt6nCFwnAZ599hpnFhP2aNWsU9hk0raKc+TPGUV4WwYDysgjzZ4zTKJ1eqMIXSZOGWoZnWkW5Ar4fVOGLDNDOnTsTwv7DDz9U2EvOUoUvMgCq6iUfqcIX6Yenn35aE6gkb6nCF0lRfNB//etfZ8OGDSG1RqT/VOGL9OHGG29MOoFKYS/5RoEv0gsz4+GHH+7a1gQqyWdpdemY2R3AT4DG6K7b3X1lkvP+HvgXoAR4yN0XpHNdkUwbMmQIn332Wcw+Bb3kuyAq/MXufk70J1nYlwD/CnwbOAu4yszOCuC6IoHrnEDVPew1gUoKRTYe2k4Adrj7TgAzexy4HHgrC9cWSZmGWkqhC6LCn2Vmb5jZEjM7LsnxcmBPt+33ovuSMrOZZlZrZrWNjY09nSYSmL/85S+aQCVFoc/AN7MXzWxrkp/Lgd8ApwLnAPuARek2yN0fcPdKd68cPnx4ul8n0isz40tf+lLMPndn6NChIbVIJHP67NJx98mpfJGZPQisSHKoARjVbXtkdJ9IaKqqqvje974Xs6+9vV0vJZGClu4onZPcfV90czqwNclpG4HTzWwMHUF/JfCjdK4rko74UP/a175GbW1tSK0RyZ50+/DvNrN6M3sDuAj4JwAzG2FmKwHc/TAwC6gBtgFPuvubaV5XpN9mzpyZdAKVwl6KRVoVvrtf3cP+vcCl3bZXAglDNkWyJT7oFyxYwNy5c0NqjUg4tJaOFLRIJMLBgwdj9mn0jRQrLa0gBalzAlX3sH/ttdcU9lLUVOFLwdEEKpHkVOFLwdi1a1dC2Dc1NSnsRaJU4UtBUFUv0jdV+JLXnnnmGb2BSiRFqvAlb8UH/fjx49m0aVNIrRHJfarwJe9cdtllSSdQKexFeqfAl7xiZqxYcWTJpl/96lfqvhFJkbp0JC/ooaxI+lThS047dOhQQtjX1NQo7EUGQBW+5CxV9SLBUoUvOefPf/5zQtjv379fYS+SJlX4klNU1Ytkjip8yQmPPvpoQti3tbUp7EUCpApfQhcf9Mcccwx/+9vfQmqNSOFShS+h+e53v5t0ApXCXiQzFPgSCjNj+fLlXds33XSTum9EMkxdOpJVeigrEh5V+JIVnW+g6k4TqESySxW+ZJyqepHcoApfMubtt99OCPv3339fYS8SElX4khGq6kVyjyp8CdTSpUszPoGquq6BiQtWMea255i4YBXVdQ2BfbdIIVOFL4GJD/pIJMKBAwcCvUZ1XQPzquppaW0DoKG5hXlV9QBMqygP9FoihUYVvqRt2rRpSSdQBR32AAtrtneFfaeW1jYW1mwP/FoihUYVvqQlPug//7XLGPf92VTXNWSk4t7b3NKv/SJyhAJfBiTZQ9lT5na8ejCT3SwjyiI0JAn3EWWRQK8jUojUpSP9kmwC1ZnXze8K+06Z6maZM2UskdKSmH2R0hLmTBkb+LVECo0CX1JmZgwZMiRmn7vT8sVxSc/PRDfLtIpy5s8YR3lZBAPKyyLMnzFOD2xFUqAuHenTjh07OP3002P27du3jxNPPBHIfjfLtIpyBbzIAKjCl16ZWULYu3tX2IO6WUTyhQJfkvr973+f8gQqdbOI5Ad16UiC+KAfMmQIBw8e7PUz6mYRyX2q8KXL9OnTk06g6ivsRSQ/KPAF6Kjqq6uru7ZnzZqlxc5ECoy6dIqcVrUUKR6q8ItUa2trQtivXLlSYS9SwNKq8M3sDuAnQGN01+3uvjLunFHAUuCLgAMPuPu/pHNd6Vl1XQMLa7azt7mFEWUR5kwZm/AwVVW9SHEKosJf7O7nRH9WJjl+GLjZ3c8CzgX+h5mdFcB1JU7n0sENzS04R9a06VwvfseOHQlhv2/fPoW9SJHIeB++u+8D9kV//8TMtgHlwFuZvnax6W3p4OnjRyacr6AXKS5BVPizzOwNM1tiZsf1dqKZjQYqgPW9nDPTzGrNrLaxsbGn0ySJZGvXfPrmy6yZd3HMvqDfQCUi+aHPwDezF81sa5Kfy4HfAKcC59BRxS/q5Xs+BzwNzHb3j3s6z90fcPdKd68cPnx4f/88RS1+7Zrdd03lwxVH/icZPHgw7s5RR+lZvUgx6rNLx90np/JFZvYgsKKHY6V0hP0yd6/qVwslZXOmjGVeVT3vPvlLDvx5TcwxVfQiklapZ2YndducDmxNco4BDwPb3P2edK4nvZtWUc7//T/fjgn7b19xrcJeRID0H9rebWbn0DHcchfwUwAzGwE85O6XAhOBq4F6M3s9+rmE4ZuSnuOOO47m5uaYfQp6EekurcB396t72L8XuDT6+6tA4sBvCURrayuDBw+O2bd69WouvPDCkFokIrlKSyvkMU2gEpH+0HCNPLRnz56EsP/ggw8U9iLSK1X4OS5+qYT4MfWgql5EUqPAz2GdSyW0tLZxYMcG1jz9v2OOt7W1aUy9iKRMgZ/DOpdK2H3X1Jj9kRNO4cD+XeE0SkTylgI/h9U/dS+fbFoes++UuSs05ElEBkSBn6PiH8oeO/FHlF3wIyBxCQURkVQo8HPMsGHD+PDDD2P2nTL3yIoVkdIS5kwZm+1miUgB0BO/HHH48GHMLCbsV69ezTOb36O8LIIB5WUR5s8Yl/BCExGRVKjCzwF9TaBSwItIEFThhyjZBKr9+/drXL2IZIQq/JBoWQQRyTZV+Fn23HPPJYS93kAlItmgCj9D4pdEmDNlbMJ7Zc844wy2bdsWUgtFpNgo8DOguq6BOf+2hdb2jqq9/ulfM33ev8eco4peRLJNXToZcMezb3aF/e67pvJJ7ZGw//nPf66wF5FQqMLPgOaWVvbc94+0H2iO2X/K3BXcccd3wmmUiBQ9BX7ADh8+nLDY2Rev+hVHn/zVkFokItJBgZ+G6roG7nj2TZpbWgESgh5il0U47pjSrLVNRCSeAn+Auj+YPfxxEw2/uS7m+OibluFHH9u1XVpi/Pyys7PcShGRIxT4A7SwZjut7Z60qj9//kvMmTI2YVimlkgQkTAp8Ado59Za3l82N2bfybc+i9lR7G1uYVpFuQJeRHKKAj8F//DgWl57569d2wlvoDrtP3PC9/5X17bWqxeRXKTA70P3sP9/66toXr0k5nj3h7LQ0Vev9epFJBdp4lUfOsN+911TY8J+6CX/lWc2v0dZ5MjIm+OOKWXh9/+TunJEJCepwu/D/sdv5+DuN2L2dVb16qcXkXyiwO9BW1sbgwbF3p4Tr7mHISd9OaQWiYikR4GfRLK16uP76ieeOjRbzRERCYT68LvZv39/Qtg3NTXxowfWxOybeOpQlv3kvGw2TUQkbarwo3p7A5XCXUQKQdFX+K+++qreQCUiRaGoA9/MmDRpUtf21KlTcXeOOqqob4uIFKiiTLZFixYlVPXuzvLly0NqkYhI5hVdH3580N93333MmjUrpNaIiGRP0QT+5MmTeemll2L2qZ9eRIpJwQX+Jfes5u0P/ta1fdqwo3lpzuSYc9avX8+ECROy3TQRkVAVVODHh/3uu6ayO+4cVfUiUqwK6qFtZ9i3HzqQsIRxU1OTwl5EilpBBX6nPfdeEbN9ytwVHH/88SG1RkQkN6QV+GZ2h5k1mNnr0Z9Lezm3xMzqzGxFT+cE5e/OvgjoeANV/Bo4IiLFKog+/MXu/s8pnHcTsA34QgDXTKrEoM1h2NSbGTb15pj9IiLFLitdOmY2EvgO8FAmr7PoinP6tV9EpJgEEfizzOwNM1tiZsf1cM69wK1Ae19fZmYzzazWzGobGxv71ZBpFeXc+8NzKC+LYEB5WYR7f3iOXlIiIgJYXyNXzOxF4MQkh34GrAOaAAd+CZzk7j+O+/xU4FJ3/+9m9l+AW9x9KimorKz02traVE4VERHAzDa5e2WyY3324bv75L7OiV7kQSDZE9KJwHejD3SPBr5gZr93939M5XtFRCQY6Y7SOanb5nRga/w57j7P3Ue6+2jgSmCVwl5EJPvS7cO/28zqzewN4CLgnwDMbISZrUy7dSIiEpi0hmW6+9U97N8LJIzJd/fVwOp0rikiIgNTkDNtRUQkUZ+jdMJkZo2QsP5ZT4bRMWJIeqZ7lBrdp77pHqUmjPt0irsPT3YgpwO/P8ystqehSNJB9yg1uk990z1KTa7dJ3XpiIgUCQW+iEiRKKTAfyDsBuQB3aPU6D71TfcoNTl1nwqmD19ERHpXSBW+iIj0QoEvIlIk8jbwc/VtW7kklXtkZqPM7GUze8vM3jSzm8Joa5hS/btkZn9vZtvNbIeZ3ZbtduYCM7vZzNzMhvVw/O7o36NtZvZrMyvK1w+lcJ9ONrPno/fpLTMbnY12BfHGqzDlzNu2clhf9+gwcLO7bzazzwObzOwFd38rS+3LFb3eJzMrAf4VuAR4D9hoZs8W030ys1HAt4B3ezh+Ph2r4341uutV4EKKbDmVvu5T1FLgTnd/wcw+RwrvCglC3lb4qcrW27bylbvvc/fN0d8/oeMfRr0xJtEEYIe773T3z4DHgctDblO2LabjRUY9jfRwOpZAHwwMAUqB/dlpWk7p9T6Z2VnAIHd/AcDdP3X3A9loWL4HfqBv2ypQqdwjAKL/WVkBrM9Ky3JLX/epHNjTbfs9iugfRjO7HGhw9y09nePua4GXgX3Rnxp335alJuaEVO4T8GWg2cyqol3NC6P/BZlxOd2l08fbtn5Dx1u2Ot+2tQhI9ratD9x9U/RtWwUn3XvU7Xs+BzwNzHb3jzPT2vAEdZ8KWR/36HY6uil6+/xpwJnAyOiuF8xskru/EmhDQ5bufaIjdyfRUVy9CzwBXAc8HFwre75wztLbtvoWwD3CzErpCPtl7l4VYPNyRgD3qQEY1W17ZHRfwejpHpnZOGAMsCX6DHYksNnMJrj7+91OnQ6sc/dPo5/7D+A8oKACP4D79B7wurvvjH6uGjiXLAR+3nbp6G1bfUvlHkVHUTwMbHP3e7LVtlySyn0CNgKnm9kYMxtMx9+nZ7PRvrC5e727n+Duo6P/X3oPGB8XYtBRrV5oZoOiRcSFdDwTKgr9uE8bgTIz61zR8ptAVh7+523go7dtpSKVezQRuBr4ZipDXAtUn/fJ3Q8Ds4AaOkLsSXd/M6wG5wozqzSzzgERTwHvAPXAFmCLuy8PrXE5pPt9cvc24BbgJTOrBwx4MCvt0NIKIiLFIZ8rfBER6QcFvohIkVDgi4gUCQW+iEiRUOCLiBQJBb6ISJFQ4IuIFIn/D++1aPgtSV4jAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"e_hat, f_hat = pot.predict(X_val, ncores = 2, compute_forces=False)\n",
"\n",
"# plt.scatter(f_val, f_hat)\n",
"# plt.plot(f_val, f_val, 'k-')\n",
"# print(\"MAE FORCE VECTOR: %.4f eV/A\" %(mae_force(f_val, f_hat)))\n",
"# print(\"RMSE FORCE COMPONENTS: %.4f eV/A\" %(mean_squared_error(f_val, f_hat)**0.5))\n",
"# plt.show()\n",
"e_hat, f_hat = pot.predict(X_val, ncores = 1, compute_forces=True)\n",
"\n",
"plt.scatter(f_val, f_hat)\n",
"plt.plot(f_val, f_val, 'k-')\n",
"print(\"MAE FORCE VECTOR: %.4f eV/A\" %(mae_force(f_val, f_hat)))\n",
"print(\"RMSE FORCE COMPONENTS: %.4f eV/A\" %(mean_squared_error(f_val, f_hat)**0.5))\n",
"plt.show()\n",
"\n",
"plt.scatter(e_val/nat_val, e_hat/nat_val)\n",
"plt.plot(e_val/nat_val, e_val/nat_val, 'k-')\n",
Expand Down
17 changes: 7 additions & 10 deletions raffy/linear_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def add_square_g(self, g, dg, X, compute_forces):
dg2 = np.array(dg2)
# Append squared descriptors
g = np.append(g, g2, axis=-1)

del g2

if compute_forces:
dg = np.append(dg, dg2, axis=-1)
del dg2
Expand Down Expand Up @@ -104,11 +104,10 @@ def get_g(self, X, g=None, dg=None, compute_forces=True,
g, dg = self.adjust_g(g, dg, X, compute_forces, train_pca)
return g, dg

def fit(self, X, Y=None, Y_en=None, noise=1e-8,
def fit(self, X, Y=None, Y_en=None,
g=None, dg=None, ncores=1, pca_comps=None,
compute_forces=True):

self.noise = noise
if pca_comps is not None:
self.use_pca = True
self.nc_pca = pca_comps
Expand Down Expand Up @@ -141,8 +140,8 @@ def fit(self, X, Y=None, Y_en=None, noise=1e-8,
del dg, g, Y, Y_en, X
gtg = np.einsum('na, nb -> ab', g_tot, g_tot)
# Add regularization
noise = self.noise*np.ones(len(gtg))
gtg[np.diag_indices_from(gtg)] += noise
reg = np.std(g_tot**2, axis=0) * np.eye(len(gtg))/1000
gtg += reg
# Cholesky Decomposition to find alpha
L_ = cholesky(gtg, lower=True)
# Calculate fY
Expand All @@ -153,18 +152,16 @@ def fit(self, X, Y=None, Y_en=None, noise=1e-8,
self.alpha = alpha
del gY, alpha, L_

def fit_local(self, X, Y, g, dg, noise=1e-8):
self.noise = noise
def fit_local(self, X, Y, g, dg):
dg = np.reshape(dg, (dg.shape[0]*3, dg.shape[2]))
Y = ut.reshape_forces(Y)
g_tot = -dg
Y_tot = Y
# ftf shape is (S, S)
gtg = np.einsum('na, nb -> ab', g_tot, g_tot)
# Add regularization
noise = self.noise*np.ones(len(gtg))
noise[-len(X):] = self.noise
gtg[np.diag_indices_from(gtg)] += noise
reg = np.std(g_tot**2, axis=0) * np.eye(len(gtg))/1000
gtg += reg
# Cholesky Decomposition to find alpha
L_ = cholesky(gtg, lower=True)
# Calculate fY
Expand Down

0 comments on commit 1315ef9

Please sign in to comment.