diff --git a/miprometheus/helpers/problem_initializer.py b/miprometheus/helpers/problem_initializer.py index 0508e7d0..f4295812 100644 --- a/miprometheus/helpers/problem_initializer.py +++ b/miprometheus/helpers/problem_initializer.py @@ -28,9 +28,6 @@ import os import argparse -import urllib.request -import sys -import time #import json from miprometheus.problems.problem_factory import ProblemFactory diff --git a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/generate_dataset.py b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/generate_dataset.py index da4ab038..b621b698 100644 --- a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/generate_dataset.py +++ b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/generate_dataset.py @@ -16,19 +16,15 @@ """Script for generating a COG dataset""" import errno -import functools import gzip -import itertools import json import multiprocessing import os import random import shutil -import traceback import numpy as np -from miprometheus.problems.seq_to_seq.vqa.cog.cog_utils import stim_generator as sg import miprometheus.problems.seq_to_seq.vqa.cog.cog_utils.task_bank as task_bank diff --git a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/json_to_img.py b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/json_to_img.py index 77ba1b02..0b4c4606 100644 --- a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/json_to_img.py +++ b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/json_to_img.py @@ -67,7 +67,7 @@ def set_outputs_from_tasks(n_epoch, tasks, objsets, for epoch_now in range(n_epoch): for task, objset in zip(tasks, objsets): target = task(objset, epoch_now) - if target is const.INVALID: + if target == const.INVALID: # For invalid target, no loss is used. Everything remains zero. pass elif isinstance(target, sg.Loc): diff --git a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/stim_generator.py b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/stim_generator.py index 6b7b196b..6e8e54f3 100644 --- a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/stim_generator.py +++ b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/stim_generator.py @@ -829,7 +829,7 @@ def render_target(movie, target): frame[:] = np.array(image)[:] else: - if target_now is const.INVALID: + if target_now == const.INVALID: string = 'invalid' elif isinstance(target_now, bool): string = 'true' if target_now else 'false' @@ -976,7 +976,7 @@ def another_attr(attr): return another_shape(attr) elif isinstance(attr, Space): return another_loc(attr) - elif attr is const.INVALID: + elif attr == const.INVALID: return attr else: raise TypeError( diff --git a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/task_generator.py b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/task_generator.py index 94e33811..f5f88686 100644 --- a/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/task_generator.py +++ b/miprometheus/problems/seq_to_seq/vqa/cog/cog_utils/task_generator.py @@ -19,7 +19,6 @@ """ from collections import defaultdict -import copy import random from miprometheus.problems.seq_to_seq.vqa.cog.cog_utils import constants as const @@ -233,7 +232,7 @@ def __str__(self): def __call__(self, objset, epoch_now): del objset - del epoch_now + #del epoch_now def set_child(self, child): """Set operators as children.""" @@ -383,7 +382,7 @@ def get_expected_input(self, should_be, objset, epoch_now): a = getattr(self, attr_type) attr = a(objset, epoch_now) # If the input is successfully evaluated - if attr is not const.INVALID and attr.has_value: + if attr != const.INVALID and attr.has_value: if attr_type == 'loc': attr = attr.get_space_to(self.space_type) attr_new_object.append(attr) @@ -436,7 +435,7 @@ def get_expected_input(self, should_be, objset, epoch_now): a = getattr(self, attr_type) attr = a(objset, epoch_now) if isinstance(a, Operator): - if attr is const.INVALID: + if attr == const.INVALID: # Can not be evaluated yet, then randomly choose one attr = sg.random_attr(attr_type) attr_expected_in.append(attr) @@ -511,7 +510,7 @@ def __call__(self, objset, epoch_now): else: objs = self.objs - if objs is const.INVALID: + if objs == const.INVALID: return const.INVALID elif len(objs) != 1: # Ambiguous or non-existent @@ -602,7 +601,7 @@ def __call__(self, objset, epoch_now): objs = self.objs(objset, epoch_now) else: objs = self.objs - if objs is const.INVALID: + if objs == const.INVALID: return const.INVALID elif len(objs) != 1: # Ambiguous or non-existent @@ -613,10 +612,10 @@ def __call__(self, objset, epoch_now): def get_expected_input(self, should_be): raise NotImplementedError() - if should_be is None: - should_be = sg.random_attr(self.attr_type) - objs = sg.Object([should_be]) - return [objs] + #if should_be is None: + # should_be = sg.random_attr(self.attr_type) + #objs = sg.Object([should_be]) + #return [objs] class Exist(Operator): @@ -704,7 +703,7 @@ def __str__(self): def __call__(self, objset, epoch_now): statement_true = self.statement(objset, epoch_now) - if statement_true is const.INVALID: + if statement_true == const.INVALID: if self.invalid_as_false: statement_true = False else: @@ -799,7 +798,7 @@ def __call__(self, objset, epoch_now): attr1 = self.attr1(objset, epoch_now) attr2 = self.attr2(objset, epoch_now) - if (attr1 is const.INVALID) or (attr2 is const.INVALID): + if (attr1 == const.INVALID) or (attr2 == const.INVALID): return const.INVALID else: return attr1 == attr2 @@ -812,8 +811,8 @@ def get_expected_input(self, should_be, objset, epoch_now): attr1_value = self.attr1(objset, epoch_now) attr2_value = self.attr2(objset, epoch_now) - attr1_fixed = attr1_value is not const.INVALID - attr2_fixed = attr2_value is not const.INVALID + attr1_fixed = attr1_value != const.INVALID + attr2_fixed = attr2_value != const.INVALID if attr1_fixed: assert attr1_value.has_value