-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
100 changed files
with
3,526 additions
and
1,360 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .gril import * |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+2.1 MB
gril/build/lib.linux-x86_64-cpython-39/mpml.cpython-39-x86_64-linux-gnu.so
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# ninja log v5 | ||
0 19238 1674245376000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
274 16567 1675726632000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
269 16888 1675728138000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
26 24273 1675902702086332969 /u/scratch1/ssamaga/mpml_altered_bdry/zigzag/build/temp.linux-x86_64-cpython-39/u/scratch1/ssamaga/mpml_altered_bdry/zigzag/multipers.o 72e776d2b1d272f3 | ||
5160 33366 1675906707159817794 /u/scratch1/ssamaga/mpml_altered_bdry/zigzag/build/temp.linux-x86_64-cpython-39/u/scratch1/ssamaga/mpml_altered_bdry/zigzag/multipers.o 72e776d2b1d272f3 | ||
0 17526 1675907214000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
31 17395 1675907889000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
257 17692 1675908048000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
308 16897 1675909436000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
231 16397 1675909668000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
257 16577 1675911626000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
290 17660 1675912935000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
44 17018 1679793102000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a | ||
13 33166 1681688576281989022 /scratch1/mukher26/mpml_graph_repo/gril/build/temp.linux-x86_64-cpython-39/scratch1/mukher26/mpml_graph_repo/gril/multipers.o 9d57bdffb2f12246 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+3.47 MB
...d/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o
Binary file not shown.
Binary file added
BIN
+3.49 MB
gril/build/temp.linux-x86_64-cpython-39/scratch1/mukher26/mpml_graph_repo/gril/multipers.o
Binary file not shown.
Binary file added
BIN
+3.49 MB
...uild/temp.linux-x86_64-cpython-39/u/scratch1/ssamaga/mpml_altered_bdry/zigzag/multipers.o
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
2 changes: 1 addition & 1 deletion
2
zigzag/mpml.egg-info/SOURCES.txt → gril/mpml.egg-info/SOURCES.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#ifndef MULTIPERS_H | ||
#define MULTIPERS_H | ||
#include <torch/extension.h> | ||
#include <iostream> | ||
#include <vector> | ||
#include <algorithm> | ||
#include <limits> | ||
#include <tuple> | ||
#include <future> | ||
#include <map> | ||
#include "utils.h" | ||
|
||
#include "./phat/compute_persistence_pairs.h" | ||
|
||
using torch::Tensor; | ||
using namespace torch::indexing; | ||
typedef std::pair<int, int> Point; | ||
|
||
class Multipers{ | ||
private: | ||
int hom_rank; | ||
std::vector<int> ranks; | ||
double step, ll_x, ll_y, res; | ||
int px, py; | ||
int l; | ||
std::vector<std::map<int, int>*> rank_info_h0; | ||
std::vector<std::map<int, int>*> rank_info_h1; | ||
int num_points_guess; | ||
void set_ranks(std::vector<int> ranks_){ | ||
this->ranks.insert(this->ranks.begin(), ranks_.begin(), ranks_.end()); | ||
} | ||
void set_step(double step){ | ||
this->step = step; | ||
} | ||
void set_res(double res){ | ||
this->res = res; | ||
} | ||
void set_l_for_worm(int l){ | ||
this->l = l; | ||
} | ||
Tensor compute_l_worm(const int d); | ||
std::vector<std::tuple<bool, Integer>> compute_filtration_along_boundary_cap(const Tensor& grid_pts_along_boundary_t, | ||
const Tensor& f, | ||
const Tensor& f_x_sorted, | ||
const Tensor& f_y_sorted, | ||
const Tensor& f_x_sorted_id, | ||
const Tensor& f_y_sorted_id, | ||
int &manual_birth_pts, | ||
int &manual_death_pts); | ||
|
||
void zigzag_pairs(std::vector<std::tuple<bool, Integer>> &simplices_birth_death, | ||
const vector<Simplex> &simplices, | ||
const int manual_birth_pts, | ||
const int manual_death_pts, | ||
std::vector<int> &num_full_bars); | ||
|
||
void num_full_bars_for_specific_d(const Tensor& filtration, | ||
const Tensor& f_x_sorted, | ||
const Tensor& f_y_sorted, | ||
const Tensor& f_x_sorted_id, | ||
const Tensor& f_y_sorted_id, | ||
const vector<Simplex>& simplices, | ||
const Point& p, | ||
int d, | ||
std::vector<int> &num_full_bars); | ||
|
||
Tensor find_maximal_worm_for_rank_k(const Tensor &filtration, | ||
const Tensor& f_x_sorted, | ||
const Tensor& f_y_sorted, | ||
const Tensor& f_x_sorted_id, | ||
const Tensor& f_y_sorted_id, | ||
const vector<Simplex> &simplices, | ||
const Point &p, | ||
const int rank, | ||
std::vector<std::map<int, int>*> rank_info); | ||
|
||
void set_grid_resolution_and_lower_left_corner(const Tensor& filtration); | ||
|
||
|
||
|
||
public: | ||
int max_threads; | ||
Multipers(const int hom_rank, const int l, double res, int step, const std::vector<int> ranks){ | ||
set_hom_rank(hom_rank); | ||
set_l_for_worm(l); | ||
// set_division_along_axes(px, py); | ||
set_step(step); | ||
set_res(res); | ||
set_ranks(ranks); | ||
this->max_threads = 1; | ||
int num_cp = (int)(1.0 / res); | ||
int num_div = num_cp / step; | ||
this->num_points_guess = (num_div * num_div); | ||
|
||
for(auto i = 0; i < num_points_guess; i++){ | ||
this->rank_info_h0.push_back(new std::map<int, int>()); | ||
this->rank_info_h1.push_back(new std::map<int, int>()); | ||
} | ||
} | ||
void set_max_jobs(int max_jobs){ | ||
this->max_threads = max_jobs; | ||
} | ||
std::vector<Tensor> compute_landscape(const std::vector<Point>& pts, const std::vector<std::tuple<Tensor, vector<Simplex>>> &batch); | ||
|
||
void set_hom_rank(int hom_rank){ | ||
this->hom_rank = hom_rank; | ||
} | ||
void refresh_rank_info(){ | ||
for(auto i = 0; i < num_points_guess; i++){ | ||
this->rank_info_h0[i] = new std::map<int, int>(); | ||
this->rank_info_h1[i] = new std::map<int, int>(); | ||
|
||
} | ||
} | ||
|
||
|
||
|
||
}; | ||
|
||
#endif |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from torch.utils.cpp_extension import load | ||
from torch import Tensor | ||
|
||
test = load('test', sources=['test.cpp'], extra_cflags= ['-fopenmp'], verbose=True) | ||
|
||
test.test(32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
develop.py | ||
develop_2.py | ||
.idea | ||
.vscode | ||
__pycache__ | ||
*.pyc | ||
.cache | ||
.pytest_cache | ||
pershom_dev/extensions_sandbox |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .slayer import SLayerExponential, SLayerRational, SLayerRationalHat | ||
from .modules import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import torch | ||
|
||
|
||
def histogram_intersection_loss(input: torch.Tensor, | ||
target: torch.Tensor, | ||
size_average: bool=True, | ||
reduce: bool=True, | ||
symetric_version: bool=True)->torch.Tensor: | ||
r""" | ||
This loss function is based on the `Histogram Intersection` score. | ||
The output is the *negative* Histogram Intersection Score. | ||
Args: | ||
input (Tensor): :math:`(N, B)` where `N = batch size` and `B = number of classes` | ||
target (Tensor): :math:`(N, B)` where `N = batch size` and `B = number of classes` | ||
size_average (bool, optional): By default, the losses are averaged | ||
over observations for each minibatch. However, if the field | ||
:attr:`size_average` is set to ``False``, the losses are instead summed | ||
for each minibatch. Ignored if :attr:`reduce` is ``False``. Default: ``True`` | ||
reduce (bool, optional): | ||
symetric_version (bool, optional): By default, the symetric version of histogram intersection | ||
is used. If false the asymetric version is used. Default: ``True`` | ||
Returns: Tensor. | ||
""" | ||
assert input.size() == target.size(), \ | ||
"input.size() != target.size(): {} != {}!".format(input.size(), target.size()) | ||
assert input.dim() == target.dim() == 2, \ | ||
"input, target must be 2 dimensional. Got dim {} resp. {}".format(input.dim(), target.dim()) | ||
|
||
minima = input.min(target) | ||
summed_minima = minima.sum(dim=1) | ||
|
||
if symetric_version: | ||
normalization_factor = (input.sum(dim=1)).max(target.sum(dim=1)) | ||
else: | ||
normalization_factor = target.sum(dim=1) | ||
|
||
loss = summed_minima / normalization_factor | ||
|
||
if reduce: | ||
loss = sum(loss) | ||
|
||
if size_average: | ||
loss = loss / input.size(0) | ||
|
||
return -loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class LinearView(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
return x.view(x.size()[0], -1) | ||
|
||
|
||
class Apply(nn.Module): | ||
def __init__(self, function): | ||
super().__init__() | ||
self.function = function | ||
|
||
def forward(self, *args, **kwargs): | ||
return self.function(*args, **kwargs) | ||
|
||
|
||
class IndependentBranchesLinear(nn.Linear): | ||
def __init__(self, in_features, out_features_branch, n_branches, bias=True): | ||
assert in_features % n_branches == 0 | ||
in_features_branch = int(in_features/n_branches) | ||
super().__init__(in_features, out_features_branch*n_branches, bias) | ||
|
||
mask = torch.zeros_like(self.weight) | ||
for i in range(n_branches): | ||
mask[i*out_features_branch:(i+1)*out_features_branch, | ||
i*in_features_branch:(i+1)*in_features_branch] = 1 | ||
|
||
self.register_buffer('mask', mask) | ||
|
||
def forward(self, inputs): | ||
return torch.nn.functional.linear(inputs, self.weight * self.mask, self.bias) | ||
|
||
|
||
class View(nn.Module): | ||
def __init__(self, view_args): | ||
super().__init__() | ||
self.view_args = view_args | ||
|
||
def forward(self, input): | ||
return input.view(*self.view_args) |
Oops, something went wrong.