Batch hard and balanced batch #45

Open
batrlatom opened this issue Aug 5, 2019 · 1 comment

@batrlatom
Hello. I am unable to make hard triplet mining work together with balanced batches. I think we had a discussion about it here, but so far the embeddings always collapse into a single point. I have tried many combinations of margin, num_classes_per_batch, and num_images_per_class, but nothing seems to work. Could you please take a look at the code and check whether there is some obvious problem? Note that with the batch_all strategy it works well.
Thanks,
Tom


import pathlib

import numpy as np
import tensorflow as tf


def train_input_fn(data_dir, params):
    data_root = pathlib.Path(data_dir)
    all_image_paths = list(data_root.glob('**/*.jpg'))
    all_directories = {'/'.join(str(i).split("/")[:-1]) for i in all_image_paths}
    print("num of labels: {}".format(len(all_directories)))
    labels_index = [i.split("/")[-1] for i in all_directories]

    # One dataset of filenames per class directory
    datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir), shuffle=False)
                for image_dir in all_directories]

    num_labels = len(all_directories)
    num_classes_per_batch = params.num_classes_per_batch
    num_images_per_class = params.num_images_per_class

    def get_label_index(s):
        # Map the file's parent directory name to its integer label
        return labels_index.index(s.numpy().decode("utf-8").split("/")[-2])

    def preprocess_image(image):
        # Decode the raw JPEG bytes, then scale pixel values to [0, 1]
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.math.divide(image, 255.0)
        return image

    def load_and_preprocess_image(path):
        image = tf.io.read_file(path)
        return (tf.py_function(preprocess_image, [image], tf.float32),
                tf.py_function(get_label_index, [path], tf.int64))

    def generator():
        while True:
            # Sample the labels that will compose the next batch
            labels = np.random.choice(range(num_labels),
                                      num_classes_per_batch,
                                      replace=False)
            for label in labels:
                for _ in range(num_images_per_class):
                    yield label

    choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
    dataset = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)

    dataset = dataset.map(load_and_preprocess_image,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Each batch holds num_classes_per_batch classes,
    # with num_images_per_class images of each
    batch_size = num_classes_per_batch * num_images_per_class
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.prefetch(1)

    return dataset
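
For what it's worth, here is the sanity check I would run on the sampling first (a minimal sketch, assuming eager execution; the params object and data path are hypothetical placeholders):

import collections
import tensorflow as tf

tf.compat.v1.enable_eager_execution()  # eager is already the default on TF 2.x

# Hypothetical hyperparameters, for illustration only
params = type("Params", (), {"num_classes_per_batch": 4,
                             "num_images_per_class": 8,
                             "num_epochs": 1})()

images, labels = next(iter(train_input_fn("data/train", params)))
counts = collections.Counter(labels.numpy().tolist())

# Every batch should contain exactly num_classes_per_batch distinct labels,
# each repeated num_images_per_class times
assert len(counts) == params.num_classes_per_batch
assert all(c == params.num_images_per_class for c in counts.values())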

@omoindrot (Owner)
I don't see anything wrong.

If the batch-all loss works and the batch-hard loss does not, this might indicate that your dataset is a bit noisy, so the hardest triplets are the mislabeled ones.
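
To see why that collapses embeddings, here is a minimal NumPy sketch of the batch-hard strategy (my paraphrase, not the repo's implementation): each anchor is paired with its farthest positive and closest negative, so a single mislabeled image dominates every batch it appears in, and mapping everything to one point caps the loss at the margin.

import numpy as np

def batch_hard_loss(dist, labels, margin):
    # dist: (B, B) pairwise embedding distances, labels: (B,) integer labels
    same = labels[:, None] == labels[None, :]
    hardest_pos = np.where(same, dist, -np.inf).max(axis=1)  # farthest positive
    hardest_neg = np.where(~same, dist, np.inf).min(axis=1)  # closest negative
    # If all embeddings collapse, dist is all zeros and the loss plateaus at margin
    return np.maximum(hardest_pos - hardest_neg + margin, 0.0).mean()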

You can also train first with batch all, then finetune at the end with batch hard.
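
A sketch of that two-phase schedule, assuming the tf.estimator setup from this repo where model_fn picks the loss from params.triplet_strategy (the step counts and model_dir are placeholders):

import tensorflow as tf

# Phase 1: warm up with batch-all, which averages over all valid triplets
# and is more robust to label noise
params.triplet_strategy = "batch_all"
warmup = tf.estimator.Estimator(model_fn, model_dir="experiments/base_model",
                                params=params)
warmup.train(lambda: train_input_fn("data/train", params), max_steps=10000)

# Phase 2: finetune with batch-hard; reusing the same model_dir means
# training resumes from the warm-up checkpoint, only the loss changes
params.triplet_strategy = "batch_hard"
finetune = tf.estimator.Estimator(model_fn, model_dir="experiments/base_model",
                                  params=params)
finetune.train(lambda: train_input_fn("data/train", params), max_steps=15000)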
