Skip to content

Commit

Permalink
Update documentations
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Oct 24, 2024
1 parent 231eb75 commit a4b6fa5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
3 changes: 1 addition & 2 deletions _modules/hippynn/interfaces/ase_interface/calculator.html
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,8 @@ <h1>Source code for hippynn.interfaces.ase_interface.calculator</h1><div class="
<span class="c1"># Convert from ASE distance (angstrom) to whatever the network uses.</span>
<span class="n">positions</span> <span class="o">=</span> <span class="n">positions</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dist_unit</span>
<span class="n">species</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">atoms</span><span class="o">.</span><span class="n">numbers</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cell</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">atoms</span><span class="o">.</span><span class="n">cell</span><span class="o">.</span><span class="n">array</span><span class="p">)</span> <span class="c1"># ExternalNieghbors doesn&#39;t take batch index</span>
<span class="n">cell</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">atoms</span><span class="o">.</span><span class="n">cell</span><span class="o">.</span><span class="n">array</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># Get pair first and second from neighbors list</span>

<span class="n">pair_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">pair_first</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
<span class="n">pair_second</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">pair_second</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
<span class="n">pair_shiftvecs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">as_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">nl</span><span class="o">.</span><span class="n">offset_vec</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
Expand Down
5 changes: 2 additions & 3 deletions _modules/hippynn/layers/indexers.html
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,9 @@ <h1>Source code for hippynn.layers.indexers</h1><div class="highlight"><pre>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">coordinates</span><span class="p">,</span> <span class="n">cell</span><span class="p">):</span>
<span class="n">strain</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span>
<span class="n">coordinates</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">coordinates</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">coordinates</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span>
<span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="p">)</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">coordinates</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
<span class="n">strained_coordinates</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">coordinates</span><span class="p">,</span> <span class="n">strain</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cell</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">strained_cell</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mm</span><span class="p">(</span><span class="n">cell</span><span class="p">,</span> <span class="n">strain</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">strained_cell</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">cell</span><span class="p">,</span> <span class="n">strain</span><span class="p">)</span>
<span class="k">return</span> <span class="n">strained_coordinates</span><span class="p">,</span> <span class="n">strained_cell</span><span class="p">,</span> <span class="n">strain</span></div>
</div>

Expand Down
13 changes: 10 additions & 3 deletions _modules/hippynn/layers/pairs/indexing.html
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,18 @@ <h1>Source code for hippynn.layers.pairs.indexing</h1><div class="highlight"><pr
<div class="viewcode-block" id="ExternalNeighbors.forward">
<a class="viewcode-back" href="../../../../api_documentation/hippynn.layers.pairs.indexing.html#hippynn.layers.pairs.indexing.ExternalNeighbors.forward">[docs]</a>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">coordinates</span><span class="p">,</span> <span class="n">real_atoms</span><span class="p">,</span> <span class="n">shifts</span><span class="p">,</span> <span class="n">cell</span><span class="p">,</span> <span class="n">pair_first</span><span class="p">,</span> <span class="n">pair_second</span><span class="p">):</span>
<span class="n">n_molecules</span><span class="p">,</span> <span class="n">n_atoms</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">coordinates</span><span class="o">.</span><span class="n">shape</span>
<span class="n">atom_coordinates</span> <span class="o">=</span> <span class="n">coordinates</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">n_molecules</span> <span class="o">*</span> <span class="n">n_atoms</span><span class="p">,</span> <span class="mi">3</span><span class="p">)[</span><span class="n">real_atoms</span><span class="p">]</span>
<span class="k">if</span> <span class="p">(</span><span class="n">coordinates</span><span class="o">.</span><span class="n">ndim</span> <span class="o">&gt;</span> <span class="mi">3</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">coordinates</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="n">coordinates</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;coordinates must have (n,3) or (1,n,3) but has shape </span><span class="si">{</span><span class="n">coordinates</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">coordinates</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="n">coordinates</span> <span class="o">=</span> <span class="n">coordinates</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <span class="p">(</span><span class="n">cell</span><span class="o">.</span><span class="n">ndim</span> <span class="o">&gt;</span> <span class="mi">3</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">cell</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="n">cell</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;cell must have (3,3) or (1,3,3) but has shape </span><span class="si">{</span><span class="n">cell</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cell</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="n">cell</span> <span class="o">=</span> <span class="n">cell</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

<span class="n">atom_coordinates</span> <span class="o">=</span> <span class="n">coordinates</span><span class="p">[</span><span class="n">real_atoms</span><span class="p">]</span>
<span class="n">paircoord</span> <span class="o">=</span> <span class="n">atom_coordinates</span><span class="p">[</span><span class="n">pair_second</span><span class="p">]</span> <span class="o">-</span> <span class="n">atom_coordinates</span><span class="p">[</span><span class="n">pair_first</span><span class="p">]</span> <span class="o">+</span> <span class="n">shifts</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">cell</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="o">@</span> <span class="n">cell</span>
<span class="n">distflat</span> <span class="o">=</span> <span class="n">paircoord</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. </span>
<span class="k">return</span> <span class="n">filter_pairs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hard_dist_cutoff</span><span class="p">,</span> <span class="n">distflat</span><span class="p">,</span> <span class="n">pair_first</span><span class="p">,</span> <span class="n">pair_second</span><span class="p">,</span> <span class="n">paircoord</span><span class="p">)</span></div>
</div>
Expand Down

0 comments on commit a4b6fa5

Please sign in to comment.