File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1116,12 +1116,11 @@ def _run_step(
11161116 self ._trainer_state , outputs = compiled_train_step_fn (self .trainer_state , input_batch )
11171117
11181118 n = self ._config .log_every_n_steps or 100
1119- if self .step % n == 0 or 0 <= self .step <= 5 :
1120- self ._step_log (
1121- "loss=%s aux=%s" ,
1122- outputs ["loss" ],
1123- jax .tree .map (lambda x : x .item () if x .ndim == 0 else f"T{ x .shape } " , outputs ["aux" ]),
1124- )
1119+ self ._step_log (
1120+ "loss=%s aux=%s" ,
1121+ outputs ["loss" ],
1122+ jax .tree .map (lambda x : x .item () if x .ndim == 0 else f"T{ x .shape } " , outputs ["aux" ]),
1123+ )
11251124
11261125 self .summary_writer (self .step , {"loss" : outputs ["loss" ], ** outputs ["summaries" ]})
11271126 # Aggregate summaries across evalers.
Original file line number Diff line number Diff line change @@ -413,7 +413,7 @@ def adamw_decoupled_learner_config(
413413 peak_lr : float ,
414414 max_step : int ,
415415 weight_decay : float ,
416- lr_warmup_steps : int = 2000 ,
416+ lr_warmup_steps : int = 50 ,
417417 alpha : float = 0.1 ,
418418 b1 : float = 0.9 ,
419419 b2 : float = 0.95 ,
@@ -451,7 +451,7 @@ def adastar_learner_config(
451451 * ,
452452 peak_lr : float ,
453453 max_step : int ,
454- lr_warmup_steps : int = 2000 ,
454+ lr_warmup_steps : int = 50 ,
455455 alpha : float = 0.005 ,
456456 weight_decay : float = 3.16e-4 ,
457457 b1 : float = 0.95 ,
Original file line number Diff line number Diff line change @@ -291,10 +291,12 @@ def get_trainer_kwargs(
291291 tokens_per_batch = TOKENS_PER_BATCH [version ]
292292 max_step = TOTAL_TOKENS [version ][model_size ] // tokens_per_batch
293293 max_sequence_length = MAX_SEQUENCE_LENGTH [version ]
294- train_batch_size = tokens_per_batch // max_sequence_length
294+ # train_batch_size = tokens_per_batch // max_sequence_length
295+ train_batch_size = 128
295296
296297 # Whether to use grouped query attention.
297298 num_kv_heads = None
299+ max_step = 300
298300 if version in (Version .V3 , Version .V3_TIKTOKEN ):
299301 num_kv_heads = 8
300302
@@ -412,6 +414,7 @@ def get_trainer_kwargs(
412414 max_sequence_length = max_sequence_length ,
413415 train_batch_size = train_batch_size ,
414416 max_step = max_step ,
417+ save_every_n_steps = 100 ,
415418 mesh_shape = mesh_shape_from_axes (data = - 1 , fsdp = 8 ),
416419 mesh_rules = (
417420 # Step time:
@@ -504,7 +507,7 @@ def get_trainer_kwargs(
504507 config_modifiers = [
505508 MeshShapeModifier .default_config ().set (
506509 # fsdp=8 is also ok, only 2% slower step time.
507- mesh_shape = mesh_shape_from_axes (data = - 1 , fsdp = 64 )
510+ mesh_shape = mesh_shape_from_axes (data = 1 , fsdp = 128 )
508511 ),
509512 RematSpecModifier .default_config ().set (
510513 remat_policies = {
You can’t perform that action at this time.
0 commit comments