OTDD between the complete dataset of CIFAR10 with itself gives non-zero value #32

Open
Anuradha-Uggi opened this issue Mar 23, 2024 · 0 comments

Anuradha-Uggi commented Mar 23, 2024

Hi,

Thanks for the great work and for releasing the code publicly. I am seeing unexpected OTDD values on the CIFAR10 dataset. The code is below.

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

loaders_tgt = load_torchvision_data('CIFAR10', valid_size = 0, resize = 28)[0]
loaders_src = load_torchvision_data('CIFAR10', valid_size = 0, resize = 28)[0]

print('===> Reading both datasets done')

dist = DatasetDistance(loaders_src['train'], loaders_src['train'],
                       method = 'precomputed_labeldist',
                       inner_ot_method = 'exact',
                       inner_ot_debiased = True,
                       debiased_loss = True,
                       p = 2, entreg = 1e-1,
                       device='cuda')

d = dist.distance()
print(f'OTDD-Exact-CompleteData(CIFAR10 Img, CIFAR10 Img)={d:8.2f}')

  1. No random subset sampling is happening. The complete dataset is read and loaded once, and I pass the src data in place of the tgt data. This should give a distance of zero, but the output is below.

$OTDD-Exact-CompleteData(CIFAR10 Img, CIFAR10 Img)= 723.36

Surprisingly, the distance is 0, as expected, when computed on only 2000 samples, as in the default example.
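As a sanity check on the expectation itself, here is a minimal sketch in plain Python (it does not use the otdd code path). It computes an exact OT cost by brute force on tiny 2-D point sets: a set compared with itself gives exactly zero, while two independently drawn subsets of the same pool generally do not. All names (`exact_ot_cost`, `pool`, the sizes) are hypothetical, and whether anything comparable happens inside the library on the full dataset is an assumption that this report does not establish.

```python
import itertools
import random


def squared_dist(a, b):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def exact_ot_cost(xs, ys):
    """Exact OT cost between two equal-size, uniformly weighted point sets.

    With uniform weights and equal sizes, an optimal plan is a permutation,
    so brute force over all permutations is exact (feasible only for tiny n).
    """
    n = len(xs)
    assert n == len(ys)
    best = min(
        sum(squared_dist(xs[i], ys[perm[i]]) for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return best / n


rng = random.Random(0)
# Hypothetical stand-in for a dataset: 40 points in the plane.
pool = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(40)]

# Case A: the very same points on both sides.
same = pool[:6]
self_cost = exact_ot_cost(same, same)

# Case B: two independent random subsets of the same pool.
sub_a = rng.sample(pool, 6)
sub_b = rng.sample(pool, 6)
subsample_cost = exact_ot_cost(sub_a, sub_b)

print(f"same set vs itself: {self_cost:.4f}")
print(f"two random subsets: {subsample_cost:.4f}")
```

If the full-dataset run behaves like case B rather than case A, that would be one possible place to look, but this is only an illustration of the property, not a diagnosis.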

  2. When I use the complete MNIST/FashionMNIST dataset in place of CIFAR10, the following error is thrown:
    $Distance computation failed. Aborting.

The underlying traceback is:
$geomloss/sinkhorn_samples.py", line 327, in lse_genred
"( B - (P * " + cost + " ) )",
TypeError: can only concatenate str (not "function") to str

But the same code works when given 2000 samples, as in the default code.
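For readers unfamiliar with this message, the sketch below reproduces the same kind of TypeError in plain Python. It only mirrors the shape of the failing line from the traceback (a cost value spliced into a string); the function names and the placeholder formula string are hypothetical and are not taken from geomloss.

```python
def squared_euclidean(x, y):
    """Hypothetical callable cost, standing in for whatever was passed."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def build_formula(cost):
    # Same pattern as the line in the traceback: the cost is concatenated
    # into a formula string, which only works if the cost is a str.
    return "( B - (P * " + cost + " ) )"


# A string cost concatenates without error (placeholder formula text).
formula = build_formula("SqDist(X,Y)")

# A callable cost raises the TypeError quoted in the traceback.
try:
    build_formula(squared_euclidean)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(formula)
print(error_message)
```

This suggests that, on the code path taken for the full dataset, the cost reached that line as a function rather than a string. Why a different path is taken than with 2000 samples is not something this sketch can show.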

Please help me understand why this could be happening, especially the CIFAR10 issue.

Thanks!

Best,
Anuradha
