Skip to content

Commit

Permalink
Added faster version of mpml
Browse files Browse the repository at this point in the history
  • Loading branch information
soham0209 committed Apr 17, 2023
1 parent 0727777 commit 41b8690
Show file tree
Hide file tree
Showing 100 changed files with 3,526 additions and 1,360 deletions.
17 changes: 9 additions & 8 deletions experiments.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions gril/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .gril import *
Binary file added gril/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added gril/__pycache__/gril.cpython-39.pyc
Binary file not shown.
Binary file added gril/__pycache__/zz.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
15 changes: 15 additions & 0 deletions gril/build/temp.linux-x86_64-cpython-39/.ninja_log
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# ninja log v5
0 19238 1674245376000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
274 16567 1675726632000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
269 16888 1675728138000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
26 24273 1675902702086332969 /u/scratch1/ssamaga/mpml_altered_bdry/zigzag/build/temp.linux-x86_64-cpython-39/u/scratch1/ssamaga/mpml_altered_bdry/zigzag/multipers.o 72e776d2b1d272f3
5160 33366 1675906707159817794 /u/scratch1/ssamaga/mpml_altered_bdry/zigzag/build/temp.linux-x86_64-cpython-39/u/scratch1/ssamaga/mpml_altered_bdry/zigzag/multipers.o 72e776d2b1d272f3
0 17526 1675907214000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
31 17395 1675907889000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
257 17692 1675908048000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
308 16897 1675909436000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
231 16397 1675909668000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
257 16577 1675911626000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
290 17660 1675912935000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
44 17018 1679793102000000000 /scratch/bell/mukher26/mpml_graph_static/zigzag/build/temp.linux-x86_64-cpython-39/scratch/bell/mukher26/mpml_graph_static/zigzag/multipers.o 9a2b7b576f7f16a
13 33166 1681688576281989022 /scratch1/mukher26/mpml_graph_repo/gril/build/temp.linux-x86_64-cpython-39/scratch1/mukher26/mpml_graph_repo/gril/multipers.o 9d57bdffb2f12246
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ rule compile



build /scratch1/mukher26/mpml_graph_repo/zigzag/build/temp.linux-x86_64-cpython-39/scratch1/mukher26/mpml_graph_repo/zigzag/./multipers.o: compile /scratch1/mukher26/mpml_graph_repo/zigzag/multipers.cpp
build /scratch1/mukher26/mpml_graph_repo/gril/build/temp.linux-x86_64-cpython-39/scratch1/mukher26/mpml_graph_repo/gril/./multipers.o: compile /scratch1/mukher26/mpml_graph_repo/gril/multipers.cpp



Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added gril/dist/mpml-0.0.0-py3.9-linux-x86_64.egg
Binary file not shown.
12 changes: 9 additions & 3 deletions zigzag/zz.py → gril/gril.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


class MultiPers:
def __init__(self, hom_rank: int, l: int, res: float, ranks: List[int]):
def __init__(self, hom_rank: int, l: int, res: float, step: int, ranks: List[int]):
# try:
# __M = load(
# 'zigzag',
Expand All @@ -45,15 +45,21 @@ def __init__(self, hom_rank: int, l: int, res: float, ranks: List[int]):

# except Exception as ex:
# print("Error was {}".format(ex))

self.mpl = mpml.Multipers(hom_rank, l, res, ranks)
# const int hom_rank, const int l, double res, int step, const std::vector<int> ranks
self.mpl = mpml.Multipers(hom_rank, l, res, step, ranks)


def compute_landscape(self, pts: List[Tuple[int]], batch: List[Tuple[Tensor, List[List[int]]]]):
    """Forward landscape computation to the compiled mpml backend.

    `pts` are grid points and each entry of `batch` pairs a filtration
    tensor with its simplex list — presumably the exact format the C++
    extension expects (NOTE(review): confirm against mpml.Multipers).
    """
    backend = self.mpl
    return backend.compute_landscape(pts, batch)

def set_max_jobs(self, njobs: int):
    """Cap the number of parallel worker jobs used by the C++ backend."""
    backend = self.mpl
    backend.set_max_jobs(njobs)

def set_hom_rank(self, hom_rank: int):
    """Select the homology dimension the backend computes landscapes for."""
    backend = self.mpl
    backend.set_hom_rank(hom_rank)

def refresh_rank_info(self):
    """Ask the backend to discard its cached per-point rank information."""
    backend = self.mpl
    backend.refresh_rank_info()


"""
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
setup.py
/scratch1/mukher26/mpml_graph_repo/zigzag/./multipers.cpp
/scratch1/mukher26/mpml_graph_repo/gril/./multipers.cpp
mpml.egg-info/PKG-INFO
mpml.egg-info/SOURCES.txt
mpml.egg-info/dependency_links.txt
Expand Down
File renamed without changes.
File renamed without changes.
254 changes: 173 additions & 81 deletions zigzag/multipers.cpp → gril/multipers.cpp

Large diffs are not rendered by default.

120 changes: 120 additions & 0 deletions gril/multipers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#ifndef MULTIPERS_H
#define MULTIPERS_H
#include <torch/extension.h>
#include <iostream>
#include <vector>
#include <algorithm>
#include <limits>
#include <tuple>
#include <future>
#include <map>
#include "utils.h"

#include "./phat/compute_persistence_pairs.h"

using torch::Tensor;
using namespace torch::indexing;
typedef std::pair<int, int> Point;

// Computes multiparameter persistence landscapes (GRIL) over a 2-D grid of
// points, maximizing "l-worms" per rank. Owns per-grid-point caches of
// already-computed ranks so repeated queries can be answered cheaply.
class Multipers{
private:
    int hom_rank;                    // homology dimension currently queried
    std::vector<int> ranks;          // target ranks k for which worms are maximized
    double step, ll_x, ll_y, res;    // grid step, lower-left corner, resolution
    int px, py;
    int l;                           // worm width parameter
    // Per-grid-point rank caches for H0 / H1. Heap-allocated maps, OWNED by
    // this object (freed in the destructor; previously these leaked).
    std::vector<std::map<int, int>*> rank_info_h0;
    std::vector<std::map<int, int>*> rank_info_h1;
    int num_points_guess;            // number of grid points the caches are sized for

    void set_ranks(std::vector<int> ranks_){
        this->ranks.insert(this->ranks.begin(), ranks_.begin(), ranks_.end());
    }
    void set_step(double step){
        this->step = step;
    }
    void set_res(double res){
        this->res = res;
    }
    void set_l_for_worm(int l){
        this->l = l;
    }
    Tensor compute_l_worm(const int d);
    std::vector<std::tuple<bool, Integer>> compute_filtration_along_boundary_cap(const Tensor& grid_pts_along_boundary_t,
                                                                                 const Tensor& f,
                                                                                 const Tensor& f_x_sorted,
                                                                                 const Tensor& f_y_sorted,
                                                                                 const Tensor& f_x_sorted_id,
                                                                                 const Tensor& f_y_sorted_id,
                                                                                 int &manual_birth_pts,
                                                                                 int &manual_death_pts);

    void zigzag_pairs(std::vector<std::tuple<bool, Integer>> &simplices_birth_death,
                      const vector<Simplex> &simplices,
                      const int manual_birth_pts,
                      const int manual_death_pts,
                      std::vector<int> &num_full_bars);

    void num_full_bars_for_specific_d(const Tensor& filtration,
                                      const Tensor& f_x_sorted,
                                      const Tensor& f_y_sorted,
                                      const Tensor& f_x_sorted_id,
                                      const Tensor& f_y_sorted_id,
                                      const vector<Simplex>& simplices,
                                      const Point& p,
                                      int d,
                                      std::vector<int> &num_full_bars);

    Tensor find_maximal_worm_for_rank_k(const Tensor &filtration,
                                        const Tensor& f_x_sorted,
                                        const Tensor& f_y_sorted,
                                        const Tensor& f_x_sorted_id,
                                        const Tensor& f_y_sorted_id,
                                        const vector<Simplex> &simplices,
                                        const Point &p,
                                        const int rank,
                                        std::vector<std::map<int, int>*> rank_info);

    void set_grid_resolution_and_lower_left_corner(const Tensor& filtration);

public:
    int max_threads;

    // hom_rank: homology dimension; l: worm width; res: grid resolution;
    // step: subsampling step along each axis; ranks: target ranks per point.
    Multipers(const int hom_rank, const int l, double res, int step, const std::vector<int> ranks){
        set_hom_rank(hom_rank);
        set_l_for_worm(l);
        set_step(step);
        set_res(res);
        set_ranks(ranks);
        this->max_threads = 1;
        // Estimate the grid-point count so caches can be pre-allocated.
        // NOTE(review): assumes res in (0, 1] and step >= 1 — confirm callers.
        int num_cp = (int)(1.0 / res);
        int num_div = num_cp / step;
        this->num_points_guess = (num_div * num_div);

        for(auto i = 0; i < num_points_guess; i++){
            this->rank_info_h0.push_back(new std::map<int, int>());
            this->rank_info_h1.push_back(new std::map<int, int>());
        }
    }

    // Free the owned rank caches (the original class leaked them).
    ~Multipers(){
        for(auto* m : this->rank_info_h0) delete m;
        for(auto* m : this->rank_info_h1) delete m;
    }
    // A shallow copy would alias the owned map pointers and double-delete
    // them on destruction, so copying is forbidden.
    Multipers(const Multipers&) = delete;
    Multipers& operator=(const Multipers&) = delete;

    void set_max_jobs(int max_jobs){
        this->max_threads = max_jobs;
    }
    std::vector<Tensor> compute_landscape(const std::vector<Point>& pts, const std::vector<std::tuple<Tensor, vector<Simplex>>> &batch);

    void set_hom_rank(int hom_rank){
        this->hom_rank = hom_rank;
    }

    // Reset the cached rank information for every grid point.
    // Clears the existing maps in place instead of allocating replacements,
    // which is what the original did — leaking the old maps on every call.
    void refresh_rank_info(){
        for(auto i = 0; i < num_points_guess; i++){
            this->rank_info_h0[i]->clear();
            this->rank_info_h1[i]->clear();
        }
    }
};

#endif
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 6 additions & 0 deletions gril/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from torch.utils.cpp_extension import load
from torch import Tensor

# JIT-compile the C++ extension in test.cpp (OpenMP enabled) and invoke
# its `test` entry point once. `ext` avoids shadowing the module name.
ext = load('test', sources=['test.cpp'], extra_cflags=['-fopenmp'], verbose=True)

ext.test(32)
9 changes: 9 additions & 0 deletions gril/torchph/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
develop.py
develop_2.py
.idea
.vscode
__pycache__
*.pyc
.cache
.pytest_cache
pershom_dev/extensions_sandbox
Empty file.
2 changes: 2 additions & 0 deletions gril/torchph/chofer_torchex/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .slayer import SLayerExponential, SLayerRational, SLayerRationalHat
from .modules import *
49 changes: 49 additions & 0 deletions gril/torchph/chofer_torchex/nn/fuctional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch


def histogram_intersection_loss(input: torch.Tensor,
                                target: torch.Tensor,
                                size_average: bool = True,
                                reduce: bool = True,
                                symetric_version: bool = True) -> torch.Tensor:
    r"""
    Loss based on the `Histogram Intersection` score: the output is the
    *negative* Histogram Intersection score.

    Args:
        input (Tensor): :math:`(N, B)` where `N = batch size` and `B = number of classes`
        target (Tensor): :math:`(N, B)` where `N = batch size` and `B = number of classes`
        size_average (bool, optional): By default, the losses are averaged
            over observations for each minibatch. However, if the field
            :attr:`size_average` is set to ``False``, the losses are instead summed
            for each minibatch. Ignored if :attr:`reduce` is ``False``. Default: ``True``
        reduce (bool, optional): If ``False``, return the per-sample loss
            vector of shape :math:`(N,)` instead of a scalar. Default: ``True``
        symetric_version (bool, optional): By default, the symmetric version of
            histogram intersection is used (normalize by the larger of the two
            row sums). If ``False`` the asymmetric version (normalize by the
            target's row sum) is used. Default: ``True``

    Returns:
        Tensor: scalar if ``reduce`` else shape :math:`(N,)`.
    """
    assert input.size() == target.size(), \
        "input.size() != target.size(): {} != {}!".format(input.size(), target.size())
    assert input.dim() == target.dim() == 2, \
        "input, target must be 2 dimensional. Got dim {} resp. {}".format(input.dim(), target.dim())

    # Element-wise minimum, summed per sample -> the intersection score.
    minima = torch.min(input, target)
    summed_minima = minima.sum(dim=1)

    if symetric_version:
        normalization_factor = torch.max(input.sum(dim=1), target.sum(dim=1))
    else:
        normalization_factor = target.sum(dim=1)

    loss = summed_minima / normalization_factor

    if reduce:
        # Tensor.sum() keeps the reduction in one C-level op (the original
        # used builtin sum(), iterating the tensor element by element).
        loss = loss.sum()
        # Bug fix: size_average is documented as "ignored if reduce is
        # False", but the original divided by the batch size unconditionally.
        if size_average:
            loss = loss / input.size(0)

    return -loss
45 changes: 45 additions & 0 deletions gril/torchph/chofer_torchex/nn/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn


class LinearView(nn.Module):
    """Flatten each sample of a batch to one dimension (batch axis kept)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)


class Apply(nn.Module):
    """Wrap an arbitrary callable so it can sit inside an nn.Module graph."""

    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, *args, **kwargs):
        fn = self.function
        return fn(*args, **kwargs)


class IndependentBranchesLinear(nn.Linear):
    """Linear layer whose weight is masked block-diagonally so that each of
    `n_branches` output groups only sees its own contiguous slice of the
    input features (the branches are independent)."""

    def __init__(self, in_features, out_features_branch, n_branches, bias=True):
        assert in_features % n_branches == 0
        in_features_branch = in_features // n_branches
        super().__init__(in_features, out_features_branch * n_branches, bias)

        # Block-diagonal 0/1 mask: branch i connects only to input slice i.
        blocks = [torch.ones(out_features_branch, in_features_branch)
                  for _ in range(n_branches)]
        self.register_buffer('mask', torch.block_diag(*blocks))

    def forward(self, inputs):
        masked_weight = self.weight * self.mask
        return torch.nn.functional.linear(inputs, masked_weight, self.bias)


class View(nn.Module):
    """Reshape every input with a fixed set of view arguments."""

    def __init__(self, view_args):
        super().__init__()
        self.view_args = view_args

    def forward(self, input):
        shape = self.view_args
        return input.view(*shape)
Loading

0 comments on commit 41b8690

Please sign in to comment.