Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Commit

Permalink
Merge pull request #43 from dcsil/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
ztjdavid authored Mar 7, 2023
2 parents 034de6a + 6a7d616 commit 26f4e98
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
36 changes: 36 additions & 0 deletions server/AI/data_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# This file is used for handel data in the database and send to the AI for prediction
from transformers import TrainingArguments, Trainer
import os
import torch
from torch.utils.data import Dataset
#
class WardrobeDataset(Dataset):
def __init__(self, weather_lst, occasion_lst, color_lst, budget_lst, style_lst, tokenizer):
self.input_ids = []
self.attention_mask = []
self.labels = []
# self.map_label = label_maps

for weather, occasion, color, budget, style in zip(weather_lst, occasion_lst, color_lst, budget_lst, style_lst):
# prep_txt = f'<startoftext>Content: {txt}\nLabel: {self.map_label[label]}<endoftext>'

prep_txt = f"Today’s weather is {weather}. I’m having a {occasion}. I prefer my clothing " \
f"color in {color}. Please give my an outfit in {style}. " \
f"Please suggest clothes that in budget {budget} if not selected " \
f"from my wardrobe"

encodings_dict = tokenizer(prep_txt, truncation=True, padding="max_length")

self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
self.attention_mask.append(torch.tensor(encodings_dict['attention_mask']))

def __len__(self):
return len(self.input_ids)

def __getitem__(self, idx):
dic = {
'input_ids': self.input_ids[idx],
'attention_mask': self.attention_mask[idx]
}
# return self.input_ids[idx], self.attention_mask[idx], self.labels[idx]
return dic
31 changes: 31 additions & 0 deletions server/AI/recommand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from data_classes import WardrobeDataset
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
import numpy as np
from transformers import TrainingArguments, Trainer
import os
import torch
from torch.utils.data import Dataset

def recommand_outfit(weather_lst, occasion_lst, color_lst, budget_lst, style_lst):
'''
:param weather_lst:
:param occasion_lst:
:param color_lst:
:param budget_lst:
:param style_lst:
:return:
This function is used to recommand outfit based on the user's preferences
'''
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token
data = WardrobeDataset(weather_lst, occasion_lst, color_lst, budget_lst, style_lst, tokenizer)
# Load the model
model = AutoModelForSequenceClassification.from_pretrained("./outfit_recommand_model")
trainer = Trainer(model=model)
os.environ["WANDB_DISABLED"] = "true"
predictions = trainer.predict(data)

return np.argmax(predictions.predictions, axis=-1)

0 comments on commit 26f4e98

Please sign in to comment.