diff --git a/flexgen/flex_opt.py b/flexgen/flex_opt.py index 24f87a90..389b5f5e 100644 --- a/flexgen/flex_opt.py +++ b/flexgen/flex_opt.py @@ -584,7 +584,9 @@ def __init__(self, config: Union[str, OptConfig], env: ExecutionEnv, path: str, - policy: Policy): + policy: Policy, + local: bool = False, + local_path: str = None): if isinstance(config, str): config = get_opt_config(config) self.config = config @@ -592,6 +594,8 @@ def __init__(self, self.path = path self.policy = policy self.num_gpu_batches = policy.num_gpu_batches + self.local = local + self.local_path = local_path layers = [] layers.append(InputEmbed(self.config, self.env, self.policy)) @@ -646,7 +650,7 @@ def init_weight(self, j): os.path.join(self.path, f"{self.config.name}-np"))) check_path = os.path.join(expanded_path, "decoder.embed_positions.weight") if not os.path.exists(check_path) and DUMMY_WEIGHT not in check_path: - download_opt_weights(self.config.name, self.path) + download_opt_weights(self.config.name, self.path, self.local, self.local_path) self.layers[j].init_weight(self.weight_home[j], expanded_path) @@ -1216,7 +1220,7 @@ def run_flexgen(args): f"hidden size (prefill): {hidden_size/GB:.3f} GB") print("init weight...") - model = OptLM(opt_config, env, args.path, policy) + model = OptLM(opt_config, env, args.path, policy, args.local, args.model) try: print("warmup - generate") @@ -1316,6 +1320,9 @@ def add_parser_arguments(parser): parser.add_argument("--overlap", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--local", action="store_true", + help="Whether to use local copy of the model weights.")
+ if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/flexgen/opt_config.py b/flexgen/opt_config.py index 57f2a159..b303b6f4 100644 --- a/flexgen/opt_config.py +++ b/flexgen/opt_config.py @@ -51,7 +51,7 @@ def hidden_bytes(self, batch_size, seq_len): def get_opt_config(name, **kwargs): if "/" in name: - name = name.split("/")[1] + name = name.split("/")[-1] name = name.lower() # Handle opt-iml-30b and opt-iml-max-30b @@ -216,21 +216,27 @@ def disable_hf_opt_init(): "_init_weights", lambda *args, **kwargs: None) -def download_opt_weights(model_name, path): +def download_opt_weights(model_name, path, local = False, local_path = None): from huggingface_hub import snapshot_download import torch - print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " - f"The downloading and cpu loading can take dozens of minutes. " - f"If it seems to get stuck, you can monitor the progress by " - f"checking the memory usage of this process.") + if not local: + print(f"Load the pre-trained pytorch weights of {model_name} from huggingface. " + f"The downloading and cpu loading can take dozens of minutes. " + f"If it seems to get stuck, you can monitor the progress by " + f"checking the memory usage of this process.") - if "opt" in model_name: - hf_model_name = "facebook/" + model_name - elif "galactica" in model_name: - hf_model_name = "facebook/" + model_name + if "opt" in model_name: + hf_model_name = "facebook/" + model_name + elif "galactica" in model_name: + hf_model_name = "facebook/" + model_name + + folder = snapshot_download(hf_model_name, allow_patterns="*.bin") + else: + print(f"Load the pre-trained pytorch weights of {model_name} from local path: " + f"{local_path}. The loading can take dozens of minutes.") + folder = local_path - folder = snapshot_download(hf_model_name, allow_patterns="*.bin") bin_files = glob.glob(os.path.join(folder, "*.bin")) if "/" in model_name: