
Commit afbaad3

Ref: Changes to make it work with Python 0.5.3. Topology v5p-128. Model 7b-fuji.
1 parent: 5437705

3 files changed: 12 additions & 10 deletions


axlearn/common/trainer.py

Lines changed: 5 additions & 6 deletions
@@ -1116,12 +1116,11 @@ def _run_step(
         self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)

         n = self._config.log_every_n_steps or 100
-        if self.step % n == 0 or 0 <= self.step <= 5:
-            self._step_log(
-                "loss=%s aux=%s",
-                outputs["loss"],
-                jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
-            )
+        self._step_log(
+            "loss=%s aux=%s",
+            outputs["loss"],
+            jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", outputs["aux"]),
+        )

         self.summary_writer(self.step, {"loss": outputs["loss"], **outputs["summaries"]})
         # Aggregate summaries across evalers.
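
Note: with the step-gating condition removed, the loss line is emitted on every training step instead of only every log_every_n_steps. As a reference, here is a small standalone sketch of what the jax.tree.map formatter in that log call produces; the contents of aux below are invented for illustration:

import jax
import jax.numpy as jnp

# Same formatter as in the log call: scalars are unwrapped to Python numbers,
# non-scalar arrays are summarized as "T" followed by their shape.
aux = {"z_loss": jnp.asarray(0.12), "per_head": jnp.ones((8, 128))}
formatted = jax.tree.map(lambda x: x.item() if x.ndim == 0 else f"T{x.shape}", aux)
print(formatted)  # {'per_head': 'T(8, 128)', 'z_loss': ~0.12}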

axlearn/experiments/text/gpt/common.py

Lines changed: 2 additions & 2 deletions
@@ -413,7 +413,7 @@ def adamw_decoupled_learner_config(
     peak_lr: float,
     max_step: int,
     weight_decay: float,
-    lr_warmup_steps: int = 2000,
+    lr_warmup_steps: int = 50,
     alpha: float = 0.1,
     b1: float = 0.9,
     b2: float = 0.95,
@@ -451,7 +451,7 @@ def adastar_learner_config(
     *,
     peak_lr: float,
     max_step: int,
-    lr_warmup_steps: int = 2000,
+    lr_warmup_steps: int = 50,
     alpha: float = 0.005,
     weight_decay: float = 3.16e-4,
     b1: float = 0.95,
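
Note: lowering the default lr_warmup_steps from 2000 to 50 pairs with the max_step = 300 override in fuji.py below; with a 2000-step warmup, a 300-step run would end while the learning rate is still ramping up. A rough back-of-the-envelope sketch, assuming a plain linear warmup (the exact schedule axlearn builds from these arguments may differ):

# Hypothetical illustration, not axlearn code: fraction of peak LR reached by a
# linear warmup after a given number of steps.
def warmup_fraction(step: int, lr_warmup_steps: int) -> float:
    return min(step / lr_warmup_steps, 1.0)

print(warmup_fraction(step=300, lr_warmup_steps=2000))  # 0.15 -> run ends at 15% of peak LR
print(warmup_fraction(step=300, lr_warmup_steps=50))    # 1.0  -> peak LR reached by step 50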

axlearn/experiments/text/gpt/fuji.py

Lines changed: 5 additions & 2 deletions
@@ -291,10 +291,12 @@ def get_trainer_kwargs(
     tokens_per_batch = TOKENS_PER_BATCH[version]
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
-    train_batch_size = tokens_per_batch // max_sequence_length
+    # train_batch_size = tokens_per_batch // max_sequence_length
+    train_batch_size = 128

     # Whether to use grouped query attention.
     num_kv_heads = None
+    max_step = 300
     if version in (Version.V3, Version.V3_TIKTOKEN):
         num_kv_heads = 8

@@ -412,6 +414,7 @@ def get_trainer_kwargs(
            max_sequence_length=max_sequence_length,
            train_batch_size=train_batch_size,
            max_step=max_step,
+           save_every_n_steps=100,
            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
            mesh_rules=(
                # Step time:
@@ -504,7 +507,7 @@ def get_trainer_kwargs(
                    config_modifiers=[
                        MeshShapeModifier.default_config().set(
                            # fsdp=8 is also ok, only 2% slower step time.
-                           mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64)
+                           mesh_shape=mesh_shape_from_axes(data=1, fsdp=128)
                        ),
                        RematSpecModifier.default_config().set(
                            remat_policies={
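
Note: together these overrides pin the run to the small fixed setup from the commit message: a hard-coded global batch of 128 sequences, a 300-step run with checkpoints every 100 steps, and a mesh with no data-parallel replication (data=1) and FSDP sharding across 128 devices. A minimal sketch of the resulting device layout, using plain JAX rather than axlearn's mesh_shape_from_axes helper:

# Illustrative only: a generic two-axis device mesh, not axlearn's own mesh builder.
# With data=1 every device is assigned to the fsdp axis (pure parameter sharding);
# the previous data=-1 setting let the data axis absorb the remaining devices.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()).reshape(1, -1)  # (data=1, fsdp=<all devices>)
mesh = Mesh(devices, axis_names=("data", "fsdp"))
print(dict(mesh.shape))  # e.g. {'data': 1, 'fsdp': 128} on a 128-device slice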
