scan and apply_layers #7901
base: master
Conversation
you can't just import it; you need to set up the import dir correctly. Take a look at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo_dynamic_shape.py#L1-L6
@JackCaoG ty. I followed your example and got it working.
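(For posterity, the linked lines set up the import path roughly like this; a sketch, see the linked file for the exact code:)

```python
import os
import sys

# Make the repo's test/ directory importable so shared test helpers resolve.
parent_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_folder)
```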
import json
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
self.assertEqual(len(mapping), num_parameters)
so you expect both values to be 10?
Unfortunately not. It looks like some small integer values (e.g. values <= 2) are shared when you put multiple copies into the HLO, but values above 2 are not shared, so we don't necessarily get 10. In any case, the precise number of parameters seems to be an implementation detail that we can't reliably test.
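For context, a minimal sketch of the kind of consistency check being discussed, assuming the `LoweringContext` bindings exercised by this test (device setup and exact constructor arguments may differ):

```python
import json
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Ten device tensors feeding one computation; XLA may deduplicate
# identical small values, so fewer than 10 parameters can result.
inputs = [torch.tensor(float(i), device=device) for i in range(10)]
result = sum(inputs)

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])
mapping = ctx.parameter_id_tensor_mapping()
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])

# The mapping and the HLO program shape should agree with each other,
# even though the absolute count is an implementation detail.
assert len(mapping) == num_parameters
```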
@@ -1077,7 +1076,9 @@ class PyLoweringContext {
   at::ScalarType dtype =
       MaybeUpcastToHostTorchType(literal.shape().element_type());
   at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
-  results[param_ids[i]] = input;
+  std::optional<int64_t> param_id = lowering_ctx.GetParameterId(device_data[i]);
+  XLA_CHECK(param_id.has_value());
when would it not have a value?
When `GetParameterId` receives a `BackendData` that is not a parameter in this lowering context, it will return `std::nullopt`. However, this loop only iterates over parameters (line 1071, `const std::vector<torch::lazy::BackendDataPtr>& device_data = lowering_ctx.GetParametersData();`), so we expect every `BackendData` there to have an ID. It seems good to enforce this invariant.
return input_data

# Extract and stack the parameters into a pytree.
params = [_extract_weights_dict(layer) for layer in layers]
what if it is a dropout layer whose parameters are more than just tensors?
If there is a dropout layer that references tensors other than model parameters (for example, the dropout probability), then those tensors will be captured as an additional HLO parameter to the `XlaComputation` object. As implemented now, `apply_layers` and `scan` will trace the first layer and then use the same captured tensor for subsequent layers. This will be a problem if the user passes different dropout probabilities for, say, a sequence of dropout layers: we'd incorrectly keep using the first dropout's probability. I'll have to dig deeper and find a solution for this.

If there's a layer that references things other than tensors, then either that thing (e.g. a `bool` field) will impact the traced HLO computation, in which case I need to add a verification that all layers trace to equivalent computations, or that thing won't impact the traced computation, in which case it won't matter to us.
example_layer = deepcopy(next(iter(layers)))

# Hollow out the weights and biases in the example layer.
example_layer = example_layer.to_empty(device=None)
is this not going to impact the cloned arg?
Could you clarify this question? I thought `to_empty` is going to destroy the values inside `example_layer`, so I `deepcopy` it beforehand as a backup.
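In other words, a small sketch of the intent (`to_empty` swaps parameter storage for uninitialized memory, so the `deepcopy` keeps a pristine copy):

```python
from copy import deepcopy

import torch.nn as nn

layer = nn.Linear(4, 4)
backup = deepcopy(layer)  # real weights preserved here

# After this, `layer`'s weights are uninitialized garbage...
layer = layer.to_empty(device=None)
# ...but `backup.weight` still holds the original values.
```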
torch_xla/experimental/scan.py
fn_output_carry_pytree, fn_output_y_pytree = flat_fn(*(fake_carry + fake_x))

# Later we'll use `fn_output_carry_spec` etc to turn flattened outputs back to a PyTree.
fn_output_carry, fn_output_carry_spec = tree_flatten(fn_output_carry_pytree)
assert fn_output_carry_spec == carry_spec
fn_output_y, fn_output_y_spec = tree_flatten(fn_output_y_pytree)
flat_y_len = len(fn_output_y)
fn_outputs = fn_output_carry + fn_output_y
what if there are in-place updates to a tensor that is not returned from the function?
I tested this, and we give the wrong answer: https://github.com/tengyifei/playground/blob/master/scan_with_in_place_updates.ipynb
In the notebook, I wrote an approach to detect and prevent in-place updates like that. TL;DR: we'll have to trace every forward of each layer and verify that they're the same.
In the latest commit, this will now fail with an assertion error instead of silently giving wrong results.
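For illustration, the kind of function that should now be rejected (a sketch assuming the `scan` API introduced in this PR; module path assumed):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan  # path assumed


def fn(carry, x):
  x.add_(1.0)  # in-place update to an input that is not returned
  return carry + x, x * 2.0


device = xm.xla_device()
init = torch.zeros(4, device=device)
xs = torch.ones(5, 4, device=device)
# Expected to raise an assertion error rather than silently
# producing wrong results.
final_carry, ys = scan(fn, init, xs)
```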
torch_xla/experimental/scan.py
def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]):
  grad_y, carry, x = pytree
is this a typo?
I don't think so -- `pytree` is a tuple of the output grad at the current step (`grad_y`), the carry at the current step (`carry`), and the input at the current step (`x`).
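For intuition, a toy version of such a backward step for a hypothetical forward `fn(carry, x) = (carry + x, carry * x)` (purely illustrative; not the PR's actual code):

```python
from typing import Tuple

import torch


def step_fn(grad_carry: torch.Tensor,
            pytree: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
  grad_y, carry, x = pytree
  # For next_carry = carry + x and y = carry * x:
  #   d(loss)/d(carry) = grad_carry * 1 + grad_y * x
  #   d(loss)/d(x)     = grad_carry * 1 + grad_y * carry
  new_grad_carry = grad_carry + grad_y * x
  grad_x = grad_carry + grad_y * carry
  return new_grad_carry, grad_x
```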
@tengyifei is this PR a 2.5 candidate?
@miladm yes, I'd like to backport this to 2.5 after addressing the comments etc.
Add the lowering of scan to an HLO While op.
Introduce apply_layers, which can sequentially apply a bunch of layers using scan underneath.
Beef up unit tests, including linear layers and decoders.
Add a regression test for parameter_id_tensor_mapping.
Add test_apply_layers.py to the test shell scripts.
Correctly import the decoder model from examples.
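For example, a sketch of the intended usage (module path and layer choice are assumptions):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.apply_layers import apply_layers  # path assumed

device = xm.xla_device()
# Eight identically structured layers applied sequentially; only the first
# is traced, and the stack is driven by a single HLO While loop.
layers = [nn.Linear(128, 128).to(device) for _ in range(8)]
x = torch.randn(16, 128, device=device)
out = apply_layers(layers, x)
```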