Skip to content

Commit

Permalink
Simple data answering system. #15
Browse files Browse the repository at this point in the history
  • Loading branch information
head-iie-vnr committed Jun 30, 2024
1 parent ca040f6 commit dc51622
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
55 changes: 55 additions & 0 deletions gen-ai/tabular/google-tapas/simple-qna-system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pandas as pd
import torch
from transformers import TapasTokenizer, TapasForQuestionAnswering

# Pretrained checkpoint name, shared by the tokenizer and the QA model
model_name = "google/tapas-base-finetuned-wtq"
tokenizer = TapasTokenizer.from_pretrained(model_name)
model = TapasForQuestionAnswering.from_pretrained(model_name)

# Read the table from the CSV file and cast every cell to a string
# in one step before handing it to the tokenizer
table = pd.read_csv('simple-table-data.csv').astype(str)


# Questions phrased against the columns and values present in the CSV data
queries = [
    "Which city does Bob belong to?",
    "What is the age of Alice?",
    "Where does David stay?",
    "Who stays in Chicago?",
    "Who is older, Bob or Eve?",
]

# Aggregation operators the model can predict, keyed by class index.
# Mapping follows the Hugging Face TAPAS usage guide for WTQ-style checkpoints.
aggregation_labels = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

# Process each query separately to avoid issues with input size
for query in queries:
    inputs = tokenizer(table=table, queries=query, padding="max_length", truncation=True, return_tensors="pt")

    # Perform the forward pass (inference only, so no gradients are tracked)
    with torch.no_grad():
        outputs = model(**inputs)

    # Turn the logits into selected cell coordinates plus one aggregation
    # class index per query
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach(),
    )

    # Print the answer for the current query
    print(f"Question: {query}")
    for coordinates, aggregation_index in zip(predicted_answer_coordinates, predicted_aggregation_indices):
        if not coordinates:
            # The model selected no cell; say so instead of printing a blank answer
            print("Answer: (no answer found)")
            continue

        # Each coordinate is a (row, column) pair into the table. Cells are
        # joined with ", " because cell values may themselves contain spaces.
        answer = ", ".join(table.iat[coordinate] for coordinate in coordinates)

        # Report the predicted aggregation operator so that an answer requiring
        # SUM/AVERAGE/COUNT is not mistaken for a plain cell lookup
        aggregation = aggregation_labels.get(aggregation_index, "NONE")
        if aggregation == "NONE":
            print(f"Answer: {answer}")
        else:
            print(f"Answer: {aggregation} > {answer}")
6 changes: 6 additions & 0 deletions gen-ai/tabular/google-tapas/simple-table-data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Name,Designation,Age,City
Alice,Professor,35,New York
Bob,Engineer,40,Los Angeles
Charlie,Doctor,50,Chicago
David,Artist,45,San Francisco
Eve,Lawyer,38,Boston

0 comments on commit dc51622

Please sign in to comment.