-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from iKintosh/dev
Release 0.2.0
- Loading branch information
Showing
22 changed files
with
535 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
/config.local | ||
/tmp | ||
/cache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[core] | ||
remote = dogs_cats | ||
['remote "dogs_cats"'] | ||
url = gdrive://16rugGb6LDbERDMD97_xNNbb8y6APci8r |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
{ | ||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json", | ||
"data": { | ||
"values": "<DVC_METRIC_DATA>" | ||
}, | ||
"title": "<DVC_METRIC_TITLE>", | ||
"mark": "rect", | ||
"encoding": { | ||
"x": { | ||
"field": "<DVC_METRIC_X>", | ||
"type": "nominal", | ||
"sort": "ascending", | ||
"title": "<DVC_METRIC_X_LABEL>" | ||
}, | ||
"y": { | ||
"field": "<DVC_METRIC_Y>", | ||
"type": "nominal", | ||
"sort": "ascending", | ||
"title": "<DVC_METRIC_Y_LABEL>" | ||
}, | ||
"color": { | ||
"aggregate": "count", | ||
"type": "quantitative" | ||
}, | ||
"facet": { | ||
"field": "rev", | ||
"type": "nominal" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json", | ||
"data": { | ||
"values": "<DVC_METRIC_DATA>" | ||
}, | ||
"title": "<DVC_METRIC_TITLE>", | ||
"mark": { | ||
"type": "line" | ||
}, | ||
"encoding": { | ||
"x": { | ||
"field": "<DVC_METRIC_X>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_X_LABEL>" | ||
}, | ||
"y": { | ||
"field": "<DVC_METRIC_Y>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_Y_LABEL>", | ||
"scale": { | ||
"zero": false | ||
} | ||
}, | ||
"color": { | ||
"field": "rev", | ||
"type": "nominal" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
{ | ||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json", | ||
"data": { | ||
"values": "<DVC_METRIC_DATA>" | ||
}, | ||
"title": "<DVC_METRIC_TITLE>", | ||
"mark": "point", | ||
"encoding": { | ||
"x": { | ||
"field": "<DVC_METRIC_X>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_X_LABEL>" | ||
}, | ||
"y": { | ||
"field": "<DVC_METRIC_Y>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_Y_LABEL>", | ||
"scale": { | ||
"zero": false | ||
} | ||
}, | ||
"color": { | ||
"field": "rev", | ||
"type": "nominal" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
{ | ||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json", | ||
"data": { | ||
"values": "<DVC_METRIC_DATA>" | ||
}, | ||
"title": "<DVC_METRIC_TITLE>", | ||
"mark": { | ||
"type": "line" | ||
}, | ||
"encoding": { | ||
"x": { | ||
"field": "<DVC_METRIC_X>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_X_LABEL>" | ||
}, | ||
"y": { | ||
"field": "<DVC_METRIC_Y>", | ||
"type": "quantitative", | ||
"title": "<DVC_METRIC_Y_LABEL>", | ||
"scale": { | ||
"zero": false | ||
} | ||
}, | ||
"color": { | ||
"field": "rev", | ||
"type": "nominal" | ||
} | ||
}, | ||
"transform": [ | ||
{ | ||
"loess": "<DVC_METRIC_Y>", | ||
"on": "<DVC_METRIC_X>", | ||
"groupby": [ | ||
"rev" | ||
], | ||
"bandwidth": 0.3 | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Add patterns of files dvc should ignore, which could improve | ||
# the performance. Learn more at | ||
# https://dvc.org/doc/user-guide/dvcignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,6 @@ | |
/dist/ | ||
/GaborNet.egg-info/ | ||
/.coverage | ||
/.ipynb_checkpoints/ | ||
/metrics.json | ||
/poetry.lock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import math | ||
from typing import Any | ||
|
||
import torch | ||
from torch.nn import Parameter | ||
from torch.nn.modules import Module, Conv2d | ||
|
||
|
||
class GaborConv2d(Module): | ||
def __init__( | ||
self, | ||
in_channels, | ||
out_channels, | ||
kernel_size, | ||
stride=1, | ||
padding=0, | ||
dilation=1, | ||
groups=1, | ||
bias=False, | ||
padding_mode="zeros", | ||
): | ||
super().__init__() | ||
|
||
self.is_calculated = False | ||
|
||
self.conv_layer = Conv2d( | ||
in_channels, | ||
out_channels, | ||
kernel_size, | ||
stride, | ||
padding, | ||
dilation, | ||
groups, | ||
bias, | ||
padding_mode, | ||
) | ||
self.kernel_size = self.conv_layer.kernel_size | ||
|
||
# small addition to avoid division by zero | ||
self.delta = 1e-3 | ||
|
||
# freq, theta, sigma are set up according to S. Meshgini, | ||
# A. Aghagolzadeh and H. Seyedarabi, "Face recognition using | ||
# Gabor filter bank, kernel principal component analysis | ||
# and support vector machine" | ||
self.freq = Parameter( | ||
(math.pi / 2) | ||
* math.sqrt(2) | ||
** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor), | ||
requires_grad=True, | ||
) | ||
self.theta = Parameter( | ||
(math.pi / 8) | ||
* torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor), | ||
requires_grad=True, | ||
) | ||
self.sigma = Parameter(math.pi / self.freq, requires_grad=True) | ||
self.psi = Parameter( | ||
math.pi * torch.rand(out_channels, in_channels), requires_grad=True | ||
) | ||
|
||
self.x0 = Parameter( | ||
torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False | ||
) | ||
self.y0 = Parameter( | ||
torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False | ||
) | ||
|
||
self.y, self.x = torch.meshgrid( | ||
[ | ||
torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]), | ||
torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]), | ||
] | ||
) | ||
self.y = Parameter(self.y) | ||
self.x = Parameter(self.x) | ||
|
||
self.weight = Parameter( | ||
torch.empty(self.conv_layer.weight.shape, requires_grad=True), | ||
requires_grad=True, | ||
) | ||
|
||
self.register_parameter("freq", self.freq) | ||
self.register_parameter("theta", self.theta) | ||
self.register_parameter("sigma", self.sigma) | ||
self.register_parameter("psi", self.psi) | ||
self.register_parameter("x_shape", self.x0) | ||
self.register_parameter("y_shape", self.y0) | ||
self.register_parameter("y_grid", self.y) | ||
self.register_parameter("x_grid", self.x) | ||
self.register_parameter("weight", self.weight) | ||
|
||
def forward(self, input_tensor): | ||
if self.training: | ||
self.calculate_weights() | ||
self.is_calculated = False | ||
if not self.training: | ||
if not self.is_calculated: | ||
self.calculate_weights() | ||
self.is_calculated = True | ||
return self.conv_layer(input_tensor) | ||
|
||
def calculate_weights(self): | ||
for i in range(self.conv_layer.out_channels): | ||
for j in range(self.conv_layer.in_channels): | ||
sigma = self.sigma[i, j].expand_as(self.y) | ||
freq = self.freq[i, j].expand_as(self.y) | ||
theta = self.theta[i, j].expand_as(self.y) | ||
psi = self.psi[i, j].expand_as(self.y) | ||
|
||
rotx = self.x * torch.cos(theta) + self.y * torch.sin(theta) | ||
roty = -self.x * torch.sin(theta) + self.y * torch.cos(theta) | ||
|
||
g = torch.exp( | ||
-0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2) | ||
) | ||
g = g * torch.cos(freq * rotx + psi) | ||
g = g / (2 * math.pi * sigma ** 2) | ||
self.conv_layer.weight.data[i, j] = g | ||
|
||
def _forward_unimplemented(self, *inputs: Any): | ||
""" | ||
code checkers makes implement this method, | ||
looks like error in PyTorch | ||
""" | ||
raise NotImplementedError |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .GaborLayers import GaborConv2d | ||
from .GaborLayer import GaborConv2d | ||
|
||
__version__ = "0.2.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
outs: | ||
- md5: 45680c6dfa0b2d63cb813bd33995301c.dir | ||
path: data |
Oops, something went wrong.