Skip to content

Commit adb23a3

Browse files
authored
Merge pull request #305 from Modalities/conversion_modalities_to_huggingface
Checkpoint Conversion to HuggingFace (GPT2)
2 parents 5525864 + e74f5fb commit adb23a3

25 files changed

Lines changed: 2526 additions & 69 deletions

src/modalities/checkpointing/torch/torch_checkpoint_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module:
5151
if self.precision is not None and self.precision.value != model_state_dtype:
5252
warning(
5353
f"WARNING: Model checkpoint was stored with precision {model_state_dtype} "
54-
"but is loaded with precision {self.precision.value}."
54+
f"but is loaded with precision {self.precision.value}."
5555
)
5656

5757
# assign=True makes sure that the model is loaded with the same precision

src/modalities/conversion/__init__.py

Whitespace-only changes.

src/modalities/conversion/gpt2/__init__.py

Whitespace-only changes.
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# coding=utf-8
2+
# This code was copied and modified from the Llama implementation of the Hugging Face Transformers library.
3+
# The original code can be found at:
4+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py
5+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
6+
#
7+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
8+
# and OPT implementations in this library. It has been modified from its
9+
# original forms to accommodate minor architectural differences compared
10+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
11+
#
12+
# Licensed under the Apache License, Version 2.0 (the "License");
13+
# you may not use this file except in compliance with the License.
14+
# You may obtain a copy of the License at
15+
#
16+
# http://www.apache.org/licenses/LICENSE-2.0
17+
#
18+
# Unless required by applicable law or agreed to in writing, software
19+
# distributed under the License is distributed on an "AS IS" BASIS,
20+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21+
# See the License for the specific language governing permissions and
22+
# limitations under the License.
23+
"""LLaMA-like GPT2 model configuration"""
24+
25+
from transformers.configuration_utils import PretrainedConfig
26+
from transformers.modeling_rope_utils import rope_config_validation
27+
28+
29+
class GPT2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPT2Model`]. It is used to instantiate a GPT2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LLaMA-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the GPT2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPT2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*):
            Unsupported. Kept only for signature compatibility with the Llama configuration this class was derived
            from; passing any value other than `None` raises a `ValueError`, because this GPT2 variant uses LayerNorm
            instead of RMSNorm.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        layer_norm_bias (`bool`, *optional*, defaults to `True`):
            Whether the layer normalization layers use an additive bias.
        layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether the layer normalization layers have learnable per-element affine parameters.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_heads

    ```python
    >>> from transformers import GPT2Model, GPT2Config

    >>> # Initializing a GPT2 with a llama-7b style configuration
    >>> configuration = GPT2Config()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = GPT2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "modalities-gpt2"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `GPT2Model`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=None,
        layer_norm_eps: float = 1e-06,
        layer_norm_bias: bool = True,
        layer_norm_elementwise_affine: bool = True,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        # The parameter is accepted only for Llama-config signature compatibility; this model uses LayerNorm.
        if rms_norm_eps is not None:
            raise ValueError("RMSNorm is not supported in GPT2 model.")
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.layer_norm_eps = layer_norm_eps
        self.layer_norm_bias = layer_norm_bias
        self.layer_norm_elementwise_affine = layer_norm_elementwise_affine
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
import shutil
3+
4+
5+
def _copy_model_files(output_dir: str):
6+
source_dir = os.path.dirname(__file__)
7+
modeling_gpt2_path = os.path.join(source_dir, "modeling_gpt2.py")
8+
configuration_gpt2_path = os.path.join(source_dir, "configuration_gpt2.py")
9+
shutil.copy(modeling_gpt2_path, output_dir)
10+
shutil.copy(configuration_gpt2_path, output_dir)
11+
12+
13+
def _change_modalities_import_to_relative_import(output_dir: str):
14+
target_modeling_file = os.path.join(output_dir, "modeling_gpt2.py")
15+
with open(target_modeling_file, "r") as file:
16+
content = file.read()
17+
content = content.replace("modalities.conversion.gpt2.configuration_gpt2", ".configuration_gpt2")
18+
with open(target_modeling_file, "w") as file:
19+
file.write(content)
20+
21+
22+
def transfer_model_code(output_dir: str):
    """Copies the required model code to the output directory and replaces modalities imports.

    After this step the converted checkpoint is self-contained: transformers can load it
    without the modalities package being installed, e.g.:

    >>> from transformers import AutoModelForCausalLM
    >>> model = AutoModelForCausalLM.from_pretrained("path/to/converted/model", trust_remote_code=True)

    Args:
        output_dir (str): Directory of the converted model.
    """
    # Order matters: the import rewrite operates on the freshly copied files.
    _copy_model_files(output_dir)
    _change_modalities_import_to_relative_import(output_dir)

0 commit comments

Comments
 (0)