Skip to content

Commit

Permalink
updata vis
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengjinaling committed Feb 29, 2024
1 parent e81a9c7 commit 2a48b98
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 16 deletions.
16 changes: 10 additions & 6 deletions DecisionNCE/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,20 @@ def __init__(self, logit_scale = 100, loss_type = "DecionNCE-T", **kwargs):
self.loss_type = loss_type
assert self.loss_type in ['DecionNCE-T', 'DecionNCE-P'], f"Unknow loss type: {loss_type}"

def forward(self, visual_features, text_features):
batch_size = visual_features.shape[0]


def get_reward_matrix(self, visual_features, text_features):
if self.loss_type == 'DecionNCE-T':
reward_matrix = get_reward_matrix_T(visual_features, text_features, logit_scale = self.logit_scale)
return get_reward_matrix_T(visual_features, text_features, logit_scale = self.logit_scale)
elif self.loss_type == 'DecionNCE-P':
reward_matrix = get_reward_matrix_P(visual_features, text_features, logit_scale = self.logit_scale)
return get_reward_matrix_P(visual_features, text_features, logit_scale = self.logit_scale)
else:
raise NotImplementedError


def forward(self, visual_features, text_features):

batch_size = visual_features.shape[0]
reward_matrix = self.get_reward_matrix(visual_features, text_features, logit_scale = self.logit_scale)

labels = torch.arange(batch_size, device=reward_matrix.device).long()
return ( F.cross_entropy(reward_matrix, labels) + \
F.cross_entropy(reward_matrix.t(), labels) ) / 2
Expand Down
8 changes: 8 additions & 0 deletions DecisionNCE/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def __init__(self, modelid="RN50", device="cuda"):
T.CenterCrop(self.model.visual.input_resolution),
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
)

def get_reward(self, visual_input, text_input):
visual_feature = self.encode_image(visual_input)
text_feature = self.encode_text(text_input)
return torch.nn.functional.cosine_similarity(visual_feature, text_feature, dim=-1)


def encode_image(self, visual_input):
if type(visual_input) != torch.Tensor:
Expand Down Expand Up @@ -91,6 +97,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
if 'model' in state_dict:
state_dict = state_dict['model']
model.load_state_dict(state_dict, strict=False)

print("========= Load Successfully ========")
return model


7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pip install -e .
### Usage

```python

import DecisionNCE
import torch
from PIL import Image
Expand All @@ -55,7 +56,7 @@ text = "Your Instruction Here"
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)

reward = model.get_reward(image, text) # please note that number of image and text should be the same
```

### API
Expand All @@ -64,12 +65,10 @@ with torch.no_grad():

Returns the DecisionNCE model specified by the model name returned by `decisionnce.available_models()`. It will download the model as necessary. The `name` argument should be `DecisionNCE-P` or `DecisionNCE-T`

The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.
The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU.

---



The model returned by `decisionnce.load()` supports the following methods:

#### `model.encode_image(image: Tensor)`
Expand Down
4 changes: 2 additions & 2 deletions script/slurm_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ datapath=$3
OMP_NUM_THREADS=1 \
srun -p ${partition} -n ${gpus} --ntasks-per-node=${gpus} --cpus-per-task=14 --gres=gpu:${gpus} \
python -u DecisionNCE/main.py \
--image_path <your image folder path > \
--meta_file <path for data annotation >
--image_path <image path > \
--meta_file <data annotation path >
Loading

0 comments on commit 2a48b98

Please sign in to comment.