-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150
Conversation
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Didn't forget to review, will give it a thorough read this afternoon |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Made a first pass and can do a second one tomorrow morning
test/dtypes/test_nf4.py
Outdated
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
def test_to_cpu(self): | ||
nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) | ||
nf4_tensor.cpu() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so this is just testing against crashes or do also expect the nf4_tensor.device to be cpu?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch. this is testing against crashes but i will add assertion on nf4_tensor.device.type == 'cpu'
torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset()) | ||
|
||
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
def test_pin_memory(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you mentioned this briefly last week but could you remind me how you figured out these would be the functions that needed to be tested. (I'm thinking ahead with a tutorial for someone who wants to upstream some new exotic dtpye and get it working with fsdp). That's probably a good candidate for what I mean by we should add another smoke test so we know for sure FSDP will work
So I ran the tests locally and they all worked and fast! So this gives me confidence the nf4 tensor now supports many new ops but it doesnt give me confidence that fsdp won't break in some way
I was hoping we could have a smoke test of the sort fsdp(torch.nn.Sequential(LinearNF4(64,64)))
that would ensure nothing breaks and that fsdp doesn't silently drop the dtype since that functionality wasn't tested for fsdp 1 and we had to rely on twitter to get that signal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree that we need a smoke test on fsdp(model)
. Not sure how to setup a multi-gpu test in torchao though. Is there some .ci files to change? Is there some example in torchAO? I am happy to fill in the actual logic into the template. As a reference, FSDP tests in pytorch are done like this pytorch/test/distributed/_composable/fsdp/test_fully_shard_training.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something identical should work the machines we have in CI, every commit is already running on 4 A10Gs linux.g5.12xlarge. No existing example since this is our first distributed test
Let's just do this, first thing we meet tomorrow
def noop_detach(func, *args, **kwargs): | ||
return args[0][0] | ||
|
||
|
||
@implements( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
more of a n00b q to @drisspg : what's up with all the args[0]
I feel like there's some sort of contract I can't quite parse
EDIT: It's the NF4 tensor, could we add some comment somewhere to make this clearer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated PR with nf4tensor = args[0]
at the begining to make it clearer
self.scaler_block_size, | ||
self.scaler_mean, | ||
self.nf4, | ||
mesh.get_group().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
n00b q: what is this doing?
Also more generally I don't follow what the 2 fsdp tests are trying to do. I think in fsdp_post_all_gather
you are testing to make sure nf4 tensors are preserved and not silently casted to some other type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't follow what the 2 fsdp tests are trying to do
This is core logic in nf4tensor.py
. unit tests happens in another file test_nf4.py
fsdp_pre_all_gather
returns a tuple of two things
- tuple[0] are
quantized_scalers
,quantization_factor
andquantized_data
. They are input for all-gather - tuple[1] are
SubclassTensorArgs
,block_size
etc are metadata to reconstruct NF4Tensor.mesh.get_group().size()
is the group size for all-gather (how many gpus). it's helpful to restore NF4Tensor.size. Eg for 2 gpus, all-gathering tensor(512) will return tensor(512 x 2)
torchao/dtypes/nf4tensor.py
Outdated
scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs) | ||
nf4 = aten_op(args[0].nf4, *args[1:], **kwargs) | ||
tensor_meta = SubclassTensorArgs( | ||
args[0].size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 This also confused me. I think what driss means is just give a human readable name to args[0] so its easier to read the code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Made a first pass and can do a second one tomorrow morning
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
aten.detach.default, | ||
] | ||
) | ||
def nf4_detach(aten_op, args, kwargs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we make that assumption that requires_grad=False
and detach
is a no-op, can we add an assertion that checks for args[0].requires_grad
?
Also, I am not sure that we need to detach all inner tensors. cc: @bdhirsh
torchao/dtypes/nf4tensor.py
Outdated
raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") | ||
ratio = nf4_tensor.numel() // math.prod(new_size) | ||
|
||
assert nf4_tensor.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: These assertion messages preferably should include the values (i.e. both nf4_tensor.quantized_scalers.size(0)
and ratio
) so that they can be more actionable.
torchao/dtypes/nf4tensor.py
Outdated
quantization_factor = aten_op(nf4_tensor.quantization_factor, *(args[1:]), **kwargs) | ||
quantized_data = aten_op(nf4_tensor.quantized_data, *(args[1:]), **kwargs) | ||
return NF4Tensor( | ||
SubclassTensorArgs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I seem to see this pattern a lot where we construct SubclassTensorArgs
directly from an existing nf4_tensor
. Perhaps, consider making this into a helper to avoid the duplication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
haha. not nit at all. added util function to keep the code dry: NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
assert ( | ||
quantized_scalers.untyped_storage().data_ptr() | ||
== out.quantized_scalers.untyped_storage().data_ptr() and | ||
quantization_factor.untyped_storage().data_ptr() | ||
== out.quantization_factor.untyped_storage().data_ptr() and | ||
quantized_data.untyped_storage().data_ptr() | ||
== out.quantized_data.untyped_storage().data_ptr() | ||
), f"Expects out's data to be the all-gather output" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may consider removing these assert
s (in the future) especially if tracing through this becomes an issue. In theory, NF4Tensor
should not need to make this kind of assert, but for now, it might be helpful for debugging as the FSDP extension is still in its early stages.
) | ||
) | ||
) and len(args) == 2: | ||
# Tensor.to(device, non_blocking) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean that if we tried to use __torch_dispatch__
, we would not be able to tell that it is simply .to(device, non_blocking=True)
without a dtype argument/dtype change?
What is the story for dequantization? Namely, what is the outer NF4Tensor
's dtype, and what happens when you call .to(dtype)
with that same dtype? (e.g. if NF4Tensor.dtype == torch.bfloat16
, what if you call NF4Tensor.to(torch.bfloat16)
?)
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
torchao/dtypes/nf4tensor.py
Outdated
|
||
NF4_OPS_TABLE: Dict[Any, Any] = {} | ||
|
||
INNER_TENSOR_NAMES_FOR_FSDP = ["quantized_scalers", "quantization_factor", "quantized_data"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I exclude two tiny tensors: nf4 (numel=16) and scaler_mean (numel=1)
when GPU > numel, we need to implement padding for inner tensors. it's not worth the time in my opinion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like it'd apply to more than just FSDP. Is that correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like it'd apply to more than just FSDP. Is that correct?
it applies general distributed case when we shard a single tensor to N GPUs. I can change the name to INNER_TENSOR_NAMES_FOR_SHARDING
if that's clearer
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
torchao/dtypes/nf4tensor.py
Outdated
assert nf4tensor.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" | ||
quantized_scalers = aten_op(nf4tensor.quantized_scalers, [nf4tensor.quantized_scalers.size(0) // ratio], **kwargs) | ||
|
||
assert nf4tensor.quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe these asserts could be unified?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good suggestion. I removed duplicative asserts with for loop over inner tensors
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
torchao/dtypes/nf4tensor.py
Outdated
|
||
NF4_OPS_TABLE: Dict[Any, Any] = {} | ||
|
||
INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this is something FSDP2 requires any Tensor subclass to have defined?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for any Tensor subclass, we prefer reusing __tensor_flatten__
to lookup inner tensors. For NF4, we define INNER_TENSOR_NAMES_FOR_SHARDING as a subset of __tensor_flatten__
because scaler_mean
and nf4
are too tiny to shard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, isn't that something you could filter with a numel
based heuristic within FSDP itself instead of requiring some tensor subclasses to communicate it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the inner tensors that are sharded needs to match the torch.chunk
implementation in the subclass, so FSDP cannot necessarily determine the tensors to shard itself. (E.g., if FSDP filtered by numel but the subclass implemented torch.chunk
to still shard some tensor smaller than the numel threshold, then there would be a correctness issue.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed to private const with underscore _INNER_TENSOR_NAMES_FOR_SHARDING
after discussion with @cpuhrsch
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the heroic work. Let's open up an issue with known gaps
yes, will open issue for the renaming work |
opened issues and linked them here |
why FSDP needs those ops
torch.chunk
/aten.split.Tensor
: dim0 sharding on parameterstorch.chunk(tensor, world_size, dim=0)
tensor.new_zeros
/aten.new_zeros.default
: allocate storage for padded params.tensor[:end_idx]
/aten.slice.Tensor
andtensor.copy_
: copy sharded params into padded paramstensor.view(-1)
/aten.view.default
: flatten ND tensors into 1Dtorch.as_strided(tensor, orig_size)
/aten.as_strided.default
: restore 1D tensors to NDtensor.pin_memory
: move cpu tensor to pin memory for nonblocking D2H copytensor.cpu()
: move gpu tensor to cpuunit test:
pytest test/dtypes/test_nf4.py
run fsdp in TorchTune
git clone https://github.com/weifengpy/torchtune.git
cd torchtune && pip install -e ".[dev]"
tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config recipes/configs/llama2/7B_qlora_single_device.yaml max_steps_per_epoch=1
user flow and gaps
step 1
: load llama2/3 from HF checkpoints. gap is memory spikes inNF4Tensor.from_tensor
[NF4][FSDP2] DTensor + fused adam on cpu #205step 2
: training loopstep 3
: save checkpoint: verify ifDTensor(NF4Tensor).full_tensor
+torch.save
works for NF4Tensorstep 4
: load checkpoint to resume finetuning: verify iftorch.load
+DTensor(NF4Tensor).distributed
works