Skip to content

Commit 14f91e2

Browse files
authored
Merge pull request #1 from atom-zh/dev
Dev
2 parents fac5d2a + a32bd39 commit 14f91e2

45 files changed

Lines changed: 5655 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
.idea/
28+
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
.hypothesis/
50+
.pytest_cache/
51+
52+
# Translations
53+
*.mo
54+
*.pot
55+
56+
# Django stuff:
57+
*.log
58+
local_settings.py
59+
db.sqlite3
60+
61+
# Flask stuff:
62+
instance/
63+
.webassets-cache
64+
65+
# Scrapy stuff:
66+
.scrapy
67+
68+
# Sphinx documentation
69+
docs/_build/
70+
71+
# PyBuilder
72+
target/
73+
74+
# Jupyter Notebook
75+
.ipynb_checkpoints
76+
77+
# pyenv
78+
.python-version
79+
80+
# celery beat schedule file
81+
celerybeat-schedule
82+
83+
# SageMath parsed files
84+
*.sage.py
85+
86+
# Environments
87+
.env
88+
.venv
89+
env/
90+
venv/
91+
ENV/
92+
env.bak/
93+
venv.bak/
94+
95+
# Spyder project settings
96+
.spyderproject
97+
.spyproject
98+
99+
# Rope project settings
100+
.ropeproject
101+
102+
# mkdocs documentation
103+
/site
104+
105+
# mypy
106+
out/
107+

base/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# -*- coding: UTF-8 -*-
2+
# !/usr/bin/python
3+
# @time :2019/6/3 11:24
4+
# @author :Mo
5+
# @function :

base/embedding.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# -*- coding: UTF-8 -*-
2+
# !/usr/bin/python
3+
# @time :2019/6/3 11:29
4+
# @author :Mo
5+
# @function :embeddings of model, base embedding of random, word2vec or bert
6+
7+
from conf.path_config import path_embedding_vector_word2vec_char, path_embedding_vector_word2vec_word
8+
from conf.path_config import path_embedding_random_char, path_embedding_random_word
9+
from data_preprocess.text_preprocess import get_ngram
10+
from keras.layers import Add, Embedding, Lambda
11+
from gensim.models import KeyedVectors
12+
from keras.models import Input, Model
13+
import numpy as np
14+
import jieba
15+
import os
16+
17+
class BaseEmbedding:
18+
def __init__(self, hyper_parameters):
19+
self.len_max = hyper_parameters.get('len_max', 50) # 文本最大长度, 建议25-50
20+
self.embed_size = hyper_parameters.get('embed_size', 300) # 嵌入层尺寸
21+
self.vocab_size = hyper_parameters.get('vocab_size', 30000) # 字典大小, 这里随便填的,会根据代码里修改
22+
self.trainable = hyper_parameters.get('trainable', False) # 是否微调, 例如静态词向量、动态词向量、微调bert层等, random也可以
23+
self.level_type = hyper_parameters.get('level_type', 'char') # 还可以填'word'
24+
self.embedding_type = hyper_parameters.get('embedding_type', 'word2vec') # 词嵌入方式,可以选择'xlnet'、'bert'、'random'、'word2vec'
25+
26+
# 自适应, 根据level_type和embedding_type判断corpus_path
27+
if self.level_type == "word":
28+
if self.embedding_type == "random":
29+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_random_word)
30+
elif self.embedding_type == "word2vec":
31+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_vector_word2vec_word)
32+
elif self.embedding_type == "bert":
33+
raise RuntimeError("bert level_type is 'char', not 'word'")
34+
elif self.embedding_type == "xlnet":
35+
raise RuntimeError("xlnet level_type is 'char', not 'word'")
36+
elif self.embedding_type == "albert":
37+
raise RuntimeError("albert level_type is 'char', not 'word'")
38+
else:
39+
raise RuntimeError("embedding_type must be 'random', 'word2vec' or 'bert'")
40+
elif self.level_type == "char":
41+
if self.embedding_type == "random":
42+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_random_char)
43+
elif self.embedding_type == "word2vec":
44+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_vector_word2vec_char)
45+
elif self.embedding_type == "bert":
46+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_bert)
47+
elif self.embedding_type == "xlnet":
48+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_xlnet)
49+
elif self.embedding_type == "albert":
50+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path', path_embedding_albert)
51+
else:
52+
raise RuntimeError("embedding_type must be 'random', 'word2vec' or 'bert'")
53+
elif self.level_type == "ngram":
54+
if self.embedding_type == "random":
55+
self.corpus_path = hyper_parameters['embedding'].get('corpus_path')
56+
if not self.corpus_path:
57+
raise RuntimeError("corpus_path must exists!")
58+
else:
59+
raise RuntimeError("embedding_type must be 'random', 'word2vec' or 'bert'")
60+
else:
61+
raise RuntimeError("level_type must be 'char' or 'word'")
62+
# 定义的符号
63+
self.ot_dict = {'[PAD]': 0,
64+
'[UNK]': 1,
65+
'[BOS]': 2,
66+
'[EOS]': 3, }
67+
self.deal_corpus()
68+
self.build()
69+
70+
def deal_corpus(self): # 处理语料
71+
pass
72+
73+
def build(self):
74+
self.token2idx = {}
75+
self.idx2token = {}
76+
77+
def sentence2idx(self, text, second_text=None):
78+
if second_text:
79+
second_text = "[SEP]" + str(second_text).upper()
80+
# text = extract_chinese(str(text).upper())
81+
text = str(text).upper()
82+
83+
if self.level_type == 'char':
84+
text = list(text)
85+
elif self.level_type == 'word':
86+
text = list(jieba.cut(text, cut_all=False, HMM=True))
87+
else:
88+
raise RuntimeError("your input level_type is wrong, it must be 'word' or 'char'")
89+
text = [text_one for text_one in text]
90+
len_leave = self.len_max - len(text)
91+
if len_leave >= 0:
92+
text_index = [self.token2idx[text_char] if text_char in self.token2idx else self.token2idx['[UNK]'] for
93+
text_char in text] + [self.token2idx['[PAD]'] for i in range(len_leave)]
94+
else:
95+
text_index = [self.token2idx[text_char] if text_char in self.token2idx else self.token2idx['[UNK]'] for
96+
text_char in text[0:self.len_max]]
97+
return text_index
98+
99+
def idx2sentence(self, idx):
100+
assert type(idx) == list
101+
text_idx = [self.idx2token[id] if id in self.idx2token else self.idx2token['[UNK]'] for id in idx]
102+
return "".join(text_idx)
103+
104+
105+
class RandomEmbedding(BaseEmbedding):
106+
def __init__(self, hyper_parameters):
107+
self.ngram_ns = hyper_parameters['embedding'].get('ngram_ns', [1, 2, 3]) # ngram信息, 根据预料获取
108+
# self.path = hyper_parameters.get('corpus_path', path_embedding_random_char)
109+
super().__init__(hyper_parameters)
110+
111+
def deal_corpus(self):
112+
token2idx = self.ot_dict.copy()
113+
count = 3
114+
if 'term' in self.corpus_path:
115+
with open(file=self.corpus_path, mode='r', encoding='utf-8') as fd:
116+
while True:
117+
term_one = fd.readline()
118+
if not term_one:
119+
break
120+
term_one = term_one.strip()
121+
if term_one not in token2idx:
122+
count = count + 1
123+
token2idx[term_one] = count
124+
125+
elif os.path.exists(self.corpus_path):
126+
with open(file=self.corpus_path, mode='r', encoding='utf-8') as fd:
127+
terms = fd.readlines()
128+
for term_one in terms:
129+
if self.level_type == 'char':
130+
text = list(term_one.replace(' ', '').strip())
131+
elif self.level_type == 'word':
132+
text = list(jieba.cut(term_one, cut_all=False, HMM=False))
133+
elif self.level_type == 'ngram':
134+
text = get_ngram(term_one, ns=self.ngram_ns)
135+
else:
136+
raise RuntimeError("your input level_type is wrong, it must be 'word', 'char', 'ngram'")
137+
for text_one in text:
138+
if text_one not in token2idx:
139+
count = count + 1
140+
token2idx[text_one] = count
141+
else:
142+
raise RuntimeError("your input corpus_path is wrong, it must be 'dict' or 'corpus'")
143+
self.token2idx = token2idx
144+
self.idx2token = {}
145+
for key, value in self.token2idx.items():
146+
self.idx2token[value] = key
147+
148+
def build(self, **kwargs):
149+
self.vocab_size = len(self.token2idx)
150+
self.input = Input(shape=(self.len_max,), dtype='int32')
151+
self.output = Embedding(self.vocab_size+1,
152+
self.embed_size,
153+
input_length=self.len_max,
154+
trainable=self.trainable,
155+
)(self.input)
156+
self.model = Model(self.input, self.output)
157+
158+
def sentence2idx(self, text, second_text=""):
159+
if second_text:
160+
second_text = "[SEP]" + str(second_text).upper()
161+
# text = extract_chinese(str(text).upper()+second_text)
162+
text =str(text).upper() + second_text
163+
if self.level_type == 'char':
164+
text = list(text)
165+
elif self.level_type == 'word':
166+
text = list(jieba.cut(text, cut_all=False, HMM=False))
167+
elif self.level_type == 'ngram':
168+
text = get_ngram(text, ns=self.ngram_ns)
169+
else:
170+
raise RuntimeError("your input level_type is wrong, it must be 'word' or 'char'")
171+
# text = [text_one for text_one in text]
172+
len_leave = self.len_max - len(text)
173+
if len_leave >= 0:
174+
text_index = [self.token2idx[text_char] if text_char in self.token2idx else self.token2idx['[UNK]'] for
175+
text_char in text] + [self.token2idx['[PAD]'] for i in range(len_leave)]
176+
else:
177+
text_index = [self.token2idx[text_char] if text_char in self.token2idx else self.token2idx['[UNK]'] for
178+
text_char in text[0:self.len_max]]
179+
return text_index
180+
181+
182+
class WordEmbedding(BaseEmbedding):
183+
def __init__(self, hyper_parameters):
184+
# self.path = hyper_parameters.get('corpus_path', path_embedding_vector_word2vec)
185+
super().__init__(hyper_parameters)
186+
187+
def build(self, **kwargs):
188+
self.embedding_type = 'word2vec'
189+
print("load word2vec start!")
190+
self.key_vector = KeyedVectors.load_word2vec_format(self.corpus_path, **kwargs)
191+
print("load word2vec end!")
192+
self.embed_size = self.key_vector.vector_size
193+
194+
self.token2idx = self.ot_dict.copy()
195+
embedding_matrix = []
196+
# 首先加self.token2idx中的四个[PAD]、[UNK]、[BOS]、[EOS]
197+
embedding_matrix.append(np.zeros(self.embed_size))
198+
embedding_matrix.append(np.random.uniform(-0.5, 0.5, self.embed_size))
199+
embedding_matrix.append(np.random.uniform(-0.5, 0.5, self.embed_size))
200+
embedding_matrix.append(np.random.uniform(-0.5, 0.5, self.embed_size))
201+
202+
for word in self.key_vector.index2entity:
203+
self.token2idx[word] = len(self.token2idx)
204+
embedding_matrix.append(self.key_vector[word])
205+
206+
# self.token2idx = self.token2idx
207+
self.idx2token = {}
208+
for key, value in self.token2idx.items():
209+
self.idx2token[value] = key
210+
211+
self.vocab_size = len(self.token2idx)
212+
embedding_matrix = np.array(embedding_matrix)
213+
self.input = Input(shape=(self.len_max,), dtype='int32')
214+
215+
self.output = Embedding(self.vocab_size,
216+
self.embed_size,
217+
input_length=self.len_max,
218+
weights=[embedding_matrix],
219+
trainable=self.trainable)(self.input)
220+
self.model = Model(self.input, self.output)

0 commit comments

Comments
 (0)