This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

upgrade to pytorch 0.4.0 #1126

Merged
merged 37 commits into from
May 29, 2018

Conversation

Contributor

@joelgrus joelgrus commented Apr 24, 2018

I think you will be pleased with the results; on the whole the codebase feels quite a bit cleaner.

For the most part I tried to follow the migration guide:

http://pytorch.org/2018/04/22/0_4_0-migration-guide.html

at a high level, that included the following:

  • get rid of every reference to torch.autograd.Variable
  • replace the typechecks for tensors with the recommended (if unidiomatic) way, and factor that out into an is_tensor function
  • deal with the new 0-dimensional tensors everywhere (.item() instead of [0])
  • replace the volatile flags on tensors with torch.no_grad() context blocks (I think I put them in the right places)
  • all the initializer functions are now suffixed with an underscore (initializer_)
  • upgrade tensorboardX to 1.2 (the version we were on didn't work with pytorch 0.4)
  • replace the clunky new() logic for making new tensors on specific devices with the new pytorch 0.4 idioms: old_tensor.new_zeros(10), old_tensor.new_tensor([x, y, z]), and so on (see the sketch after these lists)

what I did not do:

  • any of the "writing device agnostic code" stuff in that guide. that's likely a good idea, but it deserves its own PR.
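For anyone who hasn't read the migration guide, here is a minimal sketch of the idioms referred to above (made-up tensor names, not code from this PR):

import torch

old = torch.randn(3, 5)  # some existing tensor whose dtype/device we want to reuse

# 0-dimensional tensors: use .item() instead of indexing with [0]
loss = old.sum()
loss_value = loss.item()

# volatile=True is gone; wrap inference in a no_grad context instead
with torch.no_grad():
    prediction = old * 2

# device-aware construction replaces the old .new() gymnastics
zeros = old.new_zeros(10)                 # same dtype/device as old
coords = old.new_tensor([1.0, 2.0, 3.0])  # likewise

# Variable is merged into Tensor, so one isinstance check covers both
assert isinstance(zeros, torch.Tensor)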

--

When I started there were 100s of warnings; there are still two left that I'm not 100% sure how to fix:

tests/models/sniff_test.py::SniffTest::test_coreference_resolution
  /home/joelg/miniconda3/envs/pytorch40/lib/python3.6/site-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
    "num_layers={}".format(dropout, num_layers))

tests/models/sniff_test.py::SniffTest::test_machine_comprehension
  /home/joelg/miniconda3/envs/pytorch40/lib/python3.6/site-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
    "num_layers={}".format(dropout, num_layers))

both only occur in the SniffTest, so maybe they are not worth worrying about. or maybe we need to update the models that those two rely on? in any case, I'm not losing too much sleep over them.
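For context, the warning fires whenever a recurrent module is built with a single layer and a non-zero dropout value, since inter-layer dropout then has nothing to apply to. A minimal reproduction (not taken from the affected models):

import torch

# num_layers=1 with dropout=0.2 triggers exactly the UserWarning quoted above,
# because dropout is only applied between recurrent layers.
lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=1, dropout=0.2)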


# finished_costs[batch_index] could be a list of 0-tensors or 1-tensors. We use .view(-1)
# to make sure they're treated as 1-tensors and then concatenated appropriately.
# TODO(joelgrus): make sure this is correct
Contributor

This doesn't look right to me. Well, it probably does the right thing, but I think fixing this where finished_costs gets created is the right solution, instead of this one. My guess is that something is now getting created as a scalar where we really want a 1-dim tensor.

@schmmd schmmd added this to the AllenNLP Release 4.3 (Pytorch 0.4) milestone Apr 27, 2018
@joelgrus joelgrus changed the title WIP: upgrade to pytorch 0.4.0 upgrade to pytorch 0.4.0 May 8, 2018
@joelgrus
Contributor Author

joelgrus commented May 8, 2018

only via the SniffTest, which passes.

off the top of my head, there's no reason why these changes should affect existing trained models (unless I did something incorrectly). but it's possible there are other changes in pytorch 0.4 that would.

Contributor

@DeNeutoy DeNeutoy left a comment

👏 👏 👏

Looks awesome, thanks for doing this - looks like it was a lot of boring work!

A few minor comments, otherwise LGTM. I'd also like to train at least one model so that we know we didn't break anything more fundamentally - happy to do that though.

@@ -39,6 +39,11 @@
START_SYMBOL = '@start@'
END_SYMBOL = '@end@'

def is_tensor(obj: Any) -> bool:
Contributor

Can you add a comment as to why you need to do it this way instead of with isinstance?
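For reference, a purely hypothetical helper built around the plain isinstance check mentioned here (the PR's actual implementation evidently does something else, which is what this comment is asking about):

from typing import Any

import torch

def is_tensor(obj: Any) -> bool:
    # Hypothetical version: in pytorch 0.4, Variable is merged into Tensor,
    # so a single isinstance check would cover both.
    return isinstance(obj, torch.Tensor)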

# Check if we've seen the span before.
if antecedent_span in spans_to_cluster_ids.keys():
Contributor

😞

costs = torch.cat(finished_costs[batch_index])
logprobs = torch.cat(finished_model_scores[batch_index])

print("size:", finished_costs[batch_index])
Contributor

remove print statement

@@ -65,7 +64,8 @@ def decode(self,
batch_scores = self._group_scores_by_batch(finished_states)
loss = 0
for scores in batch_scores.values(): # we don't care about the batch index, just the scores
loss += -util.logsumexp(torch.cat(scores))
print("scores", scores)
Contributor

remove print statement

@@ -65,7 +64,8 @@ def decode(self,
batch_scores = self._group_scores_by_batch(finished_states)
loss = 0
for scores in batch_scores.values(): # we don't care about the batch index, just the scores
loss += -util.logsumexp(torch.cat(scores))
print("scores", scores)
loss += -util.logsumexp(torch.cat([tensor.view(-1) for tensor in scores]))
Contributor

I think doing a squeeze here is tighter - I was confused when I read this, as it isn't immediately obvious that it's a 1D tensor and that doing a view like this is the same.

Contributor Author

the actual issue here (which I didn't dig into too deeply) is that sometimes the score was a 0-tensor and sometimes it was a 1-tensor, and the .view() is ensuring it's a 1-tensor.
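Concretely (made-up values): squeeze() leaves a 0-dimensional tensor 0-dimensional, while view(-1) normalizes both cases to 1-dimensional, which is what torch.cat wants:

import torch

scalar_score = torch.tensor(1.5)    # 0-dimensional ("0-tensor")
vector_score = torch.tensor([2.5])  # 1-dimensional with one element

print(scalar_score.squeeze().dim())  # still 0
print(scalar_score.view(-1).dim())   # 1

# after view(-1) both are 1-dim, so they concatenate cleanly
combined = torch.cat([scalar_score.view(-1), vector_score.view(-1)])
print(combined.shape)                # torch.Size([2])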


print("size:", finished_costs[batch_index])
costs = torch.cat([tensor.view(-1) for tensor in finished_costs[batch_index]])
logprobs = torch.cat([tensor.view(-1) for tensor in finished_model_scores[batch_index]])
Contributor

Can you use squeeze instead of view here?

Member

@schmmd schmmd left a comment

@joelgrus you'll want to update the torch version in setup.py as well.
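Presumably something along these lines in setup.py (illustrative only; the exact pin used in the repo may differ):

# setup.py (excerpt, illustrative)
install_requires = [
    # ... other dependencies ...
    "torch>=0.4.0,<0.5.0",
]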

@DeNeutoy
Contributor

DeNeutoy commented May 8, 2018

Also to clarify - I'm not worried about our current trained models, I'm worried about our ability to train new models.

@schmmd
Member

schmmd commented May 8, 2018

We want to do the following before merging:

  1. Train at least one model.
  2. Evaluate the existing models and make sure we get the expected results.

@schmmd
Member

schmmd commented May 8, 2018

MC model evaluation matches.
TE model evaluation matches.

Contributor

@matt-gardner matt-gardner left a comment

Awesome!

elif isinstance(x, torch.autograd.Variable):
return sanitize(x.data)
elif isinstance(x, torch._TensorBase): # pylint: disable=protected-access
elif is_tensor(x):
# tensor needs to be converted to a list (and moved to cpu if necessary)
return x.cpu().tolist()
Contributor

Is it worth putting a .data in here, or does that not matter? .tolist() already does more than that?

Contributor Author

I don't think it would do anything. as I understand it, the main reason to use .data is to "break free" of the computational graph; once you convert to a list you're not part of the computational graph anyway
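A quick sanity check of that claim (illustrative):

import torch

x = torch.randn(3, requires_grad=True)

# .tolist() produces plain Python floats, so nothing downstream can participate
# in autograd regardless of whether .data was used first.
values = x.cpu().tolist()
print(type(values), type(values[0]))  # <class 'list'> <class 'float'>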

@@ -331,7 +330,7 @@ def _get_action_embeddings(state: NlvrDecoderState,
padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
for action_list in actions_to_embed]
# Shape: (group_size, num_actions)
action_tensor = Variable(state.score[0].data.new(padded_actions).long())
action_tensor = state.score[0].new_tensor(padded_actions).long()
Contributor

Can't you pass in the type to the new_tensor method? dtype=torch.long
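That is, roughly (reusing the names from the hunk above; new_tensor accepts a dtype argument in 0.4):

# instead of converting afterwards with .long():
action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)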

@@ -411,7 +411,7 @@ def _get_action_embeddings(state: WikiTablesDecoderState,
padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
for action_list in actions_to_embed]
# Shape: (group_size, num_actions)
action_tensor = Variable(state.score[0].data.new(padded_actions).long())
action_tensor = state.score[0].new_tensor(padded_actions).long()
Contributor

dtype=torch.long

@@ -501,8 +501,8 @@ def _get_entity_action_logits(self,
padded_types = [common_util.pad_sequence_to_length(type_list, max_num_actions)
for type_list in entity_types]
# Shape: (group_size, num_actions)
action_tensor = Variable(state.score[0].data.new(padded_actions).long())
type_tensor = Variable(state.score[0].data.new(padded_types).long())
action_tensor = state.score[0].new_tensor(padded_actions).long()
Contributor

dtype=torch.long

@@ -398,7 +397,7 @@ def forward(self, # type: ignore
if i in best_final_states:
best_action_indices = best_final_states[i][0].action_history[0]
if target_action_sequences is not None:
# Use a Tensor, not a Variable, to avoid a memory leak.
# Detach to avoid a memory leak.
targets = target_action_sequences[i].data
Contributor

Call .detach() instead of .data?

Contributor Author

I was a little confused on this, here's what the migration docs say:

However, .data can be unsafe in some cases. Any changes on x.data wouldn’t be tracked by autograd, and the computed gradients would be incorrect if x is needed in a backward pass. A safer alternative is to use x.detach(), which also returns a Tensor that shares data with requires_grad=False, but will have its in-place changes reported by autograd if x is needed in backward.

to me that suggests that detach still retains some connection to the computational graph and might still result in a memory leak?

Contributor

I just asked about this in the pytorch slack, and I got this response:

"Variables returned by .detach() do not hold on to the computation graph. You can almost certainly use .detach() in places where you are using .data

"They do share a version counter with the original tensor. This doesn't require any extra memory usage. It just means that changes to the detached variable increment the (shared) version counter on original tensor. So if you do backprop on the original tensor, and it's needed for a gradient computation, you would get an error message instead of just incorrect gradients."

for name, param in self._model.named_parameters()}
self._optimizer.step()
for name, param in self._model.named_parameters():
param_updates[name].sub_(param.detach().data.cpu())
param_updates[name].sub_(param.data.cpu())
Contributor

.detach() instead of .data

@@ -530,7 +513,7 @@ def test_flatten_and_batch_shift_indices(self):
[[2, 1, 0, 7],
[7, 7, 2, 3],
[0, 0, 4, 2]]])
indices = Variable(torch.LongTensor(indices))
indices = torch.from_numpy(indices).long()
Contributor

dtype=torch.long

@@ -546,8 +529,7 @@ def test_batched_index_select(self):
targets = torch.ones([2, 10, 3]).cumsum(1) - 1
# Make the second batch double it's index so they're different.
targets[1, :, :] *= 2
indices = Variable(torch.LongTensor(indices))
targets = Variable(targets)
indices = torch.from_numpy(indices).long()
Contributor

dtype=torch.long

@@ -568,8 +550,7 @@ def test_flattened_index_select(self):
targets = torch.ones([2, 6, 3]).cumsum(1) - 1
# Make the second batch double it's index so they're different.
targets[1, :, :] *= 2
indices = Variable(torch.LongTensor(indices))
targets = Variable(targets)
indices = torch.from_numpy(indices).long()
Contributor

dtype=torch.long

@@ -244,5 +244,5 @@ def test_sparse_clip_grad(self):
# Now try to clip the gradients.
_ = sparse_clip_norm([embedding.weight], 1.5)
# Final norm should be 1.5
grad = embedding.weight.grad.data.coalesce()
self.assertAlmostEqual(grad._values().norm(2.0), 1.5, places=5) # pylint: disable=protected-access
grad = embedding.weight.grad.data.coalesce() # pylint: disable=no-member
Contributor

Do you need the .data here still?

@schmmd
Member

schmmd commented May 9, 2018

BiDAF trained too: http://beaker.allenai.org/ex/ex_fh0xt1u5ytst/tasks

@murphp15
Contributor

murphp15 commented May 17, 2018

@schmmd On a number of the PRs I have noticed that you guys post links to beaker.
Is this something the public should not have access to? I would love to see the results :)

@matt-gardner
Contributor

@murphp15, beaker is an experiment management system that we're building internally at AI2. We submit jobs to a cluster using beaker, and it keeps track of the experiments that we run. It's pretty nice. The plan is to make it open source eventually, but I don't know what the timeline is for that.

@schmmd
Member

schmmd commented May 17, 2018

@murphp15 it's internal, but hopefully we'll be able to share external links someday. We're working out our plan on that currently.

@schmmd
Member

schmmd commented May 25, 2018

We plan to merge this on Tuesday.

@joelgrus joelgrus merged commit ed63b7e into master May 29, 2018
@matt-gardner matt-gardner deleted the pytorch-0.4 branch June 13, 2018 18:31
gabrielStanovsky pushed a commit to gabrielStanovsky/allennlp that referenced this pull request Sep 7, 2018
* bump pytorch to 0.4 + fix sanitize

* remove check for Variable in block_orthogonal

* rename parameter split_size=

* remove checks for Variable

* fix more tests

* fixes

* get tests to pass

* fix warnings

* get rid of some of the Variables

* more tests passing

* more elimination of variables

* finish removing all Variables

* pylint and such

* a few fixes

* move torch.no_grad into model.forward_on_instances

* more pytorch 0.4 changes

* detach() -> data

* pylint

* fix bad tensor creation

* fix types

* remove print statement

* more 0.4 goodness

* factor out is_tensor

* add no_grad to elmo command

* cleanup

* remove TODO

* address PR feedback

* replace all() with item()

* more cleanup

* further cleanup

* remove Variable

* really fix merge conflict

* fix pylint