Skip to content

Commit

Permalink
refactor inference function
Browse files Browse the repository at this point in the history
  • Loading branch information
damianoazzolini committed Nov 28, 2024
1 parent 7f8e0e7 commit d5c48b3
Showing 1 changed file with 88 additions and 80 deletions.
168 changes: 88 additions & 80 deletions pastasolver/asp_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,91 @@ def compute_minimal_set_facts(self) -> None:
if el != '':
self.cautious_consequences.append(el)

def _print_worlds_info(self) -> None:
"""
Prints worlds info.
"""
print(utils.RED + "lp" + utils.END + utils.YELLOW + " up" + utils.END)
for el in self.prob_facts_dict:
print(f"{el} : {self.prob_facts_dict[el]}")
print("_"*50)
for el in self.prob_facts_dict:
print(el, end="\t")
print("#pf", end="\t")
if not self.evidence:
print("LP/UP\tProbability")
else:
print("Probability")
lp_count = 0
up_count = 0
for el in sorted(self.model_handler.worlds_dict, key= lambda x: x.count('1')):
for val in el:
print(f"{val}", end="\t")
print(f"{el.count('1')}", end="\t")
if not self.evidence:
if self.model_handler.worlds_dict[el].model_query_count > 0 and \
self.model_handler.worlds_dict[el].model_not_query_count == 0:
print(utils.RED + "LP\t", end = "")
lp_count = lp_count + 1
elif self.model_handler.worlds_dict[el].model_query_count > 0 and \
self.model_handler.worlds_dict[el].model_not_query_count > 0:
print(utils.YELLOW + "UP\t", end="")
up_count = up_count + 1
else:
print("-\t", end="")
print(self.model_handler.worlds_dict[el].prob, end="")
if self.model_handler.worlds_dict[el].model_query_count > 0 and not self.evidence:
print(utils.END)
else:
print("")

print(f"Total number of worlds that contribute to the probability: {lp_count + up_count}")
if not self.evidence:
print(f"Only LP: {lp_count}, Only UP: {up_count}")

def _handle_missing_worlds(self) -> float:
"""
Handle missing worlds.
"""
if len(self.model_handler.worlds_dict) == 0 and len(self.prob_facts_dict) > 0:
self.lower_probability_query = 0
self.upper_probability_query = 0
utils.print_pathological_program()
return 0

missing = []
if self.pedantic:
ks = sorted(self.model_handler.worlds_dict.keys())
l : 'list[int]' = []
for el in ks:
if el != '':
l.append(int(el,2))
missing = sorted(set(range(0, 2**len(self.prob_facts_dict))).difference(l), key=lambda x: bin(x)[2:].count('1'))

ntw = len(self.model_handler.worlds_dict) + 2**(len(self.prob_facts_dict) - len(self.cautious_consequences))
nw = 2**len(self.prob_facts_dict)

# TODO: check this case
if len(self.cautious_consequences) > 0 and (ntw != nw) and not self.xor and not self.upper:
utils.print_inconsistent_program(self.stop_if_inconsistent)

if self.stop_if_inconsistent and not self.normalize_prob and len(self.prob_facts_dict) > 0:
res = ""
for el in missing:
s = "0"*(len(self.prob_facts_dict) - len(bin(el)[2:])) + bin(el)[2:]
i = 0
res = res + s + "{ "
for el in self.prob_facts_dict:
if s[i] == '1':
res += el + " "
i = i + 1
res += "}\n"
if self.pedantic:
utils.print_error_and_exit(f"Found {len(missing)} worlds without answer sets: {missing}\n{res[:-1]}.")
else:
utils.print_error_and_exit(f"Found {2**len(self.prob_facts_dict) - len(self.model_handler.worlds_dict)} worlds without answer sets.")

return sum([x.prob for x in self.model_handler.worlds_dict.values()])

def compute_probabilities(self) -> None:
'''
Expand All @@ -187,98 +272,21 @@ def compute_probabilities(self) -> None:
handle.get() # type: ignore

self.normalizing_factor = 1
# print(self.model_handler.worlds_dict)


if len(self.model_handler.worlds_dict) != 2**len(self.prob_facts_dict):
if len(self.model_handler.worlds_dict) == 0 and len(self.prob_facts_dict) > 0:
self.lower_probability_query = 0
self.upper_probability_query = 0
utils.print_pathological_program()
return

missing = []
if self.pedantic:
ks = sorted(self.model_handler.worlds_dict.keys())
l : 'list[int]' = []
for el in ks:
if el != '':
l.append(int(el,2))
missing = sorted(set(range(0, 2**len(self.prob_facts_dict))).difference(l), key=lambda x: bin(x)[2:].count('1'))

ntw = len(self.model_handler.worlds_dict) + 2**(len(self.prob_facts_dict) - len(self.cautious_consequences))
nw = 2**len(self.prob_facts_dict)

# TODO: check this case
if len(self.cautious_consequences) > 0 and (ntw != nw) and not self.xor and not self.upper:
utils.print_inconsistent_program(self.stop_if_inconsistent)

if self.stop_if_inconsistent and not self.normalize_prob and len(self.prob_facts_dict) > 0:
res = ""
for el in missing:
s = "0"*(len(self.prob_facts_dict) - len(bin(el)[2:])) + bin(el)[2:]
i = 0
res = res + s + "{ "
for el in self.prob_facts_dict:
if s[i] == '1':
res += el + " "
i = i + 1
res += "}\n"
if self.pedantic:
utils.print_error_and_exit(f"Found {len(missing)} worlds without answer sets: {missing}\n{res[:-1]}.")
else:
utils.print_error_and_exit(f"Found {2**len(self.prob_facts_dict) - len(self.model_handler.worlds_dict)} worlds without answer sets.")

norm_fact = sum([x.prob for x in self.model_handler.worlds_dict.values()])
norm_fact = self._handle_missing_worlds()

if self.normalize_prob:
self.normalizing_factor = norm_fact

if self.pedantic:
# print(f"n missing {len(missing)}")
# print(self.inconsistent_worlds)
if self.normalize_prob:
print(f"Normalizing factor: {self.normalizing_factor}")
elif not self.stop_if_inconsistent:
print(f"P(inc) = {1 - norm_fact}")

if self.pedantic:
print(utils.RED + "lp" + utils.END + utils.YELLOW + " up" + utils.END)
for el in self.prob_facts_dict:
print(f"{el} : {self.prob_facts_dict[el]}")
print("_"*50)
for el in self.prob_facts_dict:
print(el, end="\t")
print("#pf", end="\t")
if not self.evidence:
print("LP/UP\tProbability")
else:
print("Probability")
lp_count = 0
up_count = 0
for el in sorted(self.model_handler.worlds_dict, key= lambda x: x.count('1')):
for val in el:
print(f"{val}", end="\t")
print(f"{el.count('1')}", end="\t")
if not self.evidence:
if self.model_handler.worlds_dict[el].model_query_count > 0 and \
self.model_handler.worlds_dict[el].model_not_query_count == 0:
print(utils.RED + "LP\t", end = "")
lp_count = lp_count + 1
elif self.model_handler.worlds_dict[el].model_query_count > 0 and \
self.model_handler.worlds_dict[el].model_not_query_count > 0:
print(utils.YELLOW + "UP\t", end="")
up_count = up_count + 1
else:
print("-\t", end="")
print(self.model_handler.worlds_dict[el].prob, end="")
if self.model_handler.worlds_dict[el].model_query_count > 0 and not self.evidence:
print(utils.END)
else:
print("")

print(f"Total number of worlds that contribute to the probability: {lp_count + up_count}")
if not self.evidence:
print(f"Only LP: {lp_count}, Only UP: {up_count}")
self._print_worlds_info()

self.lower_probability_query, self.upper_probability_query = self.model_handler.compute_lower_upper_probability(self.k_credal)
if self.normalizing_factor == 0:
Expand Down

0 comments on commit d5c48b3

Please sign in to comment.