Bug report
`expansion_factor_real_data > 1` generates placeholder data and then truncates the inputs in `loss_fn` to the correct size, which is supposed to eliminate the placeholder data. With gradient accumulation, `train_step` reshapes the inputs to introduce a `gradient_accumulation_steps` dimension.
If we truncated first and then reshaped, this would work correctly. However, we reshape and then truncate, which means later gradient accumulation steps use the placeholder data.
I believe `max_checkify` does not catch this issue because it happens too early in the process.
(Internally we're on an older fork of this codebase, so I apologize if this has already been fixed. I looked through the relevant code and it looked like it would have the same issue.)
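The ordering problem above can be reproduced with a minimal NumPy sketch (hypothetical shapes and variable names, not the actual codebase's code): placeholder rows appended after the real data survive into later microbatches when we reshape before truncating.

```python
import numpy as np

# Hypothetical sizes for illustration only.
grad_accum_steps = 2
real_rows = 4                    # real rows per accumulation step
expansion = 2                    # stand-in for expansion_factor_real_data

# Batch layout: real data first, placeholder data appended after it.
batch = np.concatenate([
    np.ones(real_rows * grad_accum_steps),                      # real
    np.zeros(real_rows * grad_accum_steps * (expansion - 1)),   # placeholder
])

# Correct order: truncate away the placeholder rows, then reshape.
correct = batch[: real_rows * grad_accum_steps].reshape(grad_accum_steps, real_rows)

# Buggy order: reshape first, then truncate each microbatch.
micro = batch.reshape(grad_accum_steps, -1)   # placeholder rows fall into later steps
buggy = micro[:, :real_rows]

print(correct)   # every microbatch is real data (all ones)
print(buggy)     # step 0 is real data, step 1 is all placeholder (zeros)
```

With truncate-then-reshape every microbatch contains only real data; with reshape-then-truncate the placeholder rows land in the later accumulation steps, which is exactly the bug described above.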
Logs/Output
No response
Environment Information
No response
Additional Context
No response