models.py

import logging

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.roformer.modeling_roformer import (
    RoFormerEmbeddings,
    RoFormerModel,
    RoFormerEncoder,
    RoFormerOnlyMLMHead,
    RoFormerForMaskedLM,
    RoFormerLayer,
    RoFormerAttention,
    RoFormerIntermediate,
    RoFormerOutput,
    RoFormerSelfAttention,
)

from accelerate.logging import get_logger

logger = get_logger(__name__)


class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Tie the token-type embeddings to the word-embedding table, so
        # token_type_ids index the same vocabulary as input_ids.
        self.token_type_embeddings = self.word_embeddings
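

# --- Illustrative sketch (not part of the original file) ---------------------
# A minimal check, under hypothetical config values, that the token-type
# lookup reuses the word-embedding table. RoFormerConfig here is the standard
# transformers config class; the sizes are made up for illustration.
def _tied_embeddings_sketch():
    from transformers import RoFormerConfig

    cfg = RoFormerConfig(vocab_size=1000, embedding_size=64, hidden_size=64, pad_token_id=0)
    emb = JRoFormerEmbeddings(cfg)
    # The two attributes refer to the same nn.Embedding instance.
    assert emb.token_type_embeddings is emb.word_embeddings
    return emb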


class JRoFormerSelfAttention(RoFormerSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Re-create the Q/K/V projections so the bias can be toggled via
        # `config.use_bias`, an attribute this model expects on its config.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)


class JRoFormerAttention(RoFormerAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = JRoFormerSelfAttention(config)


class JRoFormerLayer(RoFormerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = RoFormerAttention(config)
        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)


class JRoFormerEncoder(RoFormerEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(
            [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )


class JRoFormerModel(RoFormerModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = JRoFormerEmbeddings(config)

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )

        self.encoder = JRoFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class JRoFormerForMaskedLM(RoFormerForMaskedLM):
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roformer = JRoFormerModel(config)
        self.cls = RoFormerOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()
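

# --- Usage sketch (not part of the original file) ----------------------------
# Runs a masked-LM forward pass with random token ids. The checkpoint path is
# hypothetical; the checkpoint's config must provide the custom `use_bias`
# attribute that JRoFormerSelfAttention reads.
def _masked_lm_sketch(pretrain_path="../models/jtrans-pretrain"):
    model = JRoFormerForMaskedLM.from_pretrained(pretrain_path)
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    labels = input_ids.clone()
    out = model(input_ids=input_ids, labels=labels)
    # out.loss is the MLM cross-entropy; out.logits has shape
    # (batch, seq_len, vocab_size).
    return out.loss, out.logits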


class JtransEncoder(nn.Module):
    def __init__(self, dim, pretrain_path):
        super().__init__()
        self.encoder = JRoFormerModel.from_pretrained(pretrain_path)
        self.config = self.encoder.config
        # Projection head mapping the pooled hidden state to a `dim`-dimensional embedding.
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc2 = nn.Linear(self.config.hidden_size, dim)
        self.activation = nn.ReLU()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state

        if attention_mask is not None:
            # Mean-pool over non-padding positions only (the mask is cast to
            # half precision; PyTorch type promotion handles fp32 activations).
            mask_expanded = attention_mask.unsqueeze(-1).expand(output.size()).half()
            masked_output = output * mask_expanded
            sum_masked_output = torch.sum(masked_output, dim=1)
            sum_attention_mask = torch.sum(mask_expanded, dim=1)
            pooled_output = sum_masked_output / sum_attention_mask
        else:
            pooled_output = torch.mean(output, dim=1)

        output = self.fc1(pooled_output.to(self.fc1.weight.dtype))
        output = self.activation(output)
        output = self.fc2(output)
        return output
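

# --- Usage sketch (not part of the original file) ----------------------------
# Embeds a small batch of token-id sequences into `dim`-dimensional vectors.
# The checkpoint path, batch contents, and `dim` value are hypothetical.
def _jtrans_encoder_sketch(pretrain_path="../models/cebin", dim=128):
    encoder = JtransEncoder(dim=dim, pretrain_path=pretrain_path)
    input_ids = torch.randint(0, encoder.config.vocab_size, (4, 32))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        embeddings = encoder(input_ids, attention_mask=attention_mask)
    # One embedding per sequence: shape (4, dim).
    return embeddings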


class JtransPairEncoder(nn.Module):
    def __init__(self, pretrain_path="../models/cebin"):
        super().__init__()
        self.encoder = JRoFormerModel.from_pretrained(pretrain_path)
        self.config = self.encoder.config
        # Binary classification head producing a match score for a packed pair.
        self.fc = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.cls = nn.Linear(self.config.hidden_size, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state

        if attention_mask is not None:
            # Mean-pool over non-padding positions only.
            mask_expanded = attention_mask.unsqueeze(-1).expand(output.size()).half()
            masked_output = output * mask_expanded
            sum_masked_output = torch.sum(masked_output, dim=1)
            sum_attention_mask = torch.sum(mask_expanded, dim=1)
            pooled_output = sum_masked_output / sum_attention_mask
        else:
            pooled_output = torch.mean(output, dim=1)

        output = self.fc(pooled_output.to(self.fc.weight.dtype))
        output = self.relu(output)
        output = self.cls(output)
        output = self.sigmoid(output)
        return output
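

# --- Usage sketch (not part of the original file) ----------------------------
# Scores a packed function pair with the pair encoder. The checkpoint path and
# input layout are hypothetical; the real pairing format (and how
# token_type_ids are populated) is defined by the training pipeline, not here.
def _jtrans_pair_sketch(pretrain_path="../models/cebin"):
    model = JtransPairEncoder(pretrain_path=pretrain_path)
    input_ids = torch.randint(0, model.config.vocab_size, (1, 32))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        score = model(input_ids, attention_mask=attention_mask)
    # score is a sigmoid output of shape (1, 1) in [0, 1].
    return score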