A Character-Vector Extraction Tool for Pretrained BERT Models -- Encoding Sentences with BERT


This article presents two examples of encoding sentences with BERT, i.e. extracting vectors from a BERT model.

(1) Character-vector extraction tool for pretrained BERT models
This tool reads a pretrained BERT model directly, extracts the vector of every character that appears in the sample files, and saves the result to a vector file that supplies embeddings for downstream models.

The tool needs nothing beyond the pretrained model itself, and because the vectors of all characters occurring in the samples are extracted in one pass, downstream models can perform their embedding lookups very quickly. A short usage sketch follows the source listing below.
Full source code on GitHub: https://github.com/xmxoxo/BERT-Vector

#!/usr/bin/env python
# coding: utf-8
__author__ = 'xmxoxo'
'''
BERT pretrained model character-vector extraction tool
Version: v 0.3.2
Updated: 2020/3/25 11:11
git: https://github.com/xmxoxo/BERT-Vector/
'''

import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
import os
import sys
import traceback
import pickle

gblVersion = '0.3.2'
# Change these if the model files are named differently
model_name = 'bert_model.ckpt'
vocab_name = 'vocab.txt'

# BERT embedding extraction class
class bert_embdding():
    def __init__(self, model_path='', fmt='pkl'):
        # Paths to the checkpoint and the vocabulary file
        ckpt_path = os.path.join(model_path, model_name)
        vocab_file = os.path.join(model_path, vocab_name)
        if not os.path.isfile(vocab_file):
            print('词表文件不存在,请检查...')
            #sys.exit()
            return
        # Read the embedding layer from the checkpoint
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
        #param_dict = reader.get_variable_to_shape_map()
        self.emb = reader.get_tensor("bert/embeddings/word_embeddings")
        self.vocab = open(vocab_file, 'r', encoding='utf-8').read().split("\n")
        print('embeddings size: %s' % str(self.emb.shape))
        print('词表大小:%d' % len(self.vocab))
        # Output format ('pkl' or 'txt')
        self.fmt = fmt

    # Return the embedding vector for a given character
    def get_embdding(self, char):
        if char in self.vocab:
            index = self.vocab.index(char)
            return self.emb[index, :]
        else:
            return None

    # Extract the vectors for all characters in a string and save them to a file
    def export(self, txt_all, out_file=''):
        # De-duplicate the characters
        txt_lst = sorted(list(set(txt_all)))
        print('文本字典长度:%d, 正在提取字向量...' % len(txt_lst))
        count = 0
        # Two output formats are supported 2020/3/25
        if self.fmt == 'pkl':
            print('正在保存为pkl格式文件...')
            # Store as a dict, which is more convenient to use 2020/3/23
            lst_vector = dict()
            for word in txt_lst:
                v = self.get_embdding(word)
                if not (v is None):
                    count += 1
                    lst_vector[word] = v
            # Save with pickle 2020/3/23
            with open(out_file, 'wb') as out:
                pickle.dump(lst_vector, out, 2)
        if self.fmt == 'txt':
            print('正在保存为txt格式文件...')
            with open(out_file, 'w', encoding='utf-8') as out:
                for word in txt_lst:
                    v = self.get_embdding(word)
                    if not (v is None):
                        count += 1
                        out.write(word + " " + " ".join([str(i) for i in v]) + "\n")
        print('字向量共提取:%d个' % count)

    # Get all files and folders in a path
    # fileExt: ['png','jpg','jpeg']
    # return: a list of folders and files, like [['./aa'], ['./aa/abc.txt']]
    @staticmethod
    def getFiles(workpath, fileExt=[]):
        try:
            lstFiles = []
            lstFloders = []
            if os.path.isdir(workpath):
                for dirname in os.listdir(workpath):
                    file_path = os.path.join(workpath, dirname)
                    if os.path.isfile(file_path):
                        if fileExt:
                            if dirname[dirname.rfind('.')+1:] in fileExt:
                                lstFiles.append(file_path)
                        else:
                            lstFiles.append(file_path)
                    if os.path.isdir(file_path):
                        lstFloders.append(file_path)
            elif os.path.isfile(workpath):
                lstFiles.append(workpath)
            else:
                return None
            lstRet = [lstFloders, lstFiles]
            return lstRet
        except Exception as e:
            return None

    # Batch-process files of the given types under a directory v 0.1.2 xmxoxo 2020/3/23
    def export_path(self, path, ext=['csv', 'txt'], out_file=''):
        try:
            files = self.getFiles(path, ext)
            # Merge the characters from all data files
            txt_all = []
            tmp = ''
            for fn in files[1]:
                print('正在读取数据文件:%s' % fn)
                with open(fn, 'r', encoding='utf-8') as f:
                    tmp = f.read()
                txt_all += list(set(tmp))
            txt_all = list(set(txt_all))
            self.export(txt_all, out_file=out_file)
        except Exception as e:
            print('批量处理出错:')
            print('Error in export_path: ' + traceback.format_exc())
            return None

# Command-line entry point
def main_cli():
    parser = argparse.ArgumentParser(description='BERT模型字向量提取工具')
    parser.add_argument('-v', '--version', action='version', version='%(prog)s ' + gblVersion)
    parser.add_argument('--model_path', default='', required=True, type=str, help='BERT预训练模型的目录')
    parser.add_argument('--in_file', default='', required=True, type=str, help='待提取的文件名或者目录名')
    parser.add_argument('--out_file', default='./bert_embedding.pkl', type=str, help='输出文件名')
    parser.add_argument('--ext', default=['csv', 'txt'], type=str, nargs='+', help='指定目录时读取的数据文件扩展名')
    parser.add_argument('--fmt', default='pkl', type=str, help='输出文件的格式,可设置txt或者pkl, 默认为pkl')
    args = parser.parse_args()

    # Directory of the pretrained model
    model_path = args.model_path
    # Output file name
    out_file = args.out_file
    # File or directory containing the text
    in_file = args.in_file
    # File extensions to read when a directory is given
    ext = args.ext
    # Output format
    fmt = args.fmt
    if not fmt in ['pkl', 'txt']:
        fmt = 'pkl'
    if fmt == 'txt' and out_file[-4:] == '.pkl':
        out_file = out_file[:-3] + 'txt'

    if not os.path.isdir(model_path):
        print('模型目录不存在,请检查:%s' % model_path)
        sys.exit()
    if not (os.path.isfile(in_file) or os.path.isdir(in_file)):
        print('数据文件不存在,请检查:%s' % in_file)
        sys.exit()

    print('\nBERT 字向量提取工具 V' + gblVersion)
    print('-' * 40)
    bertemb = bert_embdding(model_path=model_path, fmt=fmt)
    # Handle a single file and a directory separately 2020/3/23 by xmxoxo
    if os.path.isfile(in_file):
        txt_all = open(in_file, 'r', encoding='utf-8').read()
        bertemb.export(txt_all, out_file=out_file)
    if os.path.isdir(in_file):
        bertemb.export_path(in_file, ext=ext, out_file=out_file)

if __name__ == '__main__':
    main_cli()
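
As a quick illustration (not part of the original tool): if the listing above is saved as, say, bert_vector.py (a hypothetical filename) it can be run as `python bert_vector.py --model_path chinese_L-12_H-768_A-12 --in_file data/ --out_file bert_embedding.pkl --fmt pkl`. The sketch below then shows one way to load the exported .pkl file, which the tool writes as a dict mapping each character to its vector, and assemble it into an embedding matrix for a downstream model.

import pickle
import numpy as np

# Load the dict {character: embedding vector} written by the tool
# ('bert_embedding.pkl' is the tool's default --out_file).
with open('bert_embedding.pkl', 'rb') as f:
    char2vec = pickle.load(f)

# Build a character->id mapping and an embedding matrix; row 0 is reserved
# for padding. The matrix can be used to initialize an embedding layer.
chars = sorted(char2vec.keys())
char2id = {c: i + 1 for i, c in enumerate(chars)}
emb_dim = len(next(iter(char2vec.values())))   # 768 for BERT-Base
embedding_matrix = np.zeros((len(chars) + 1, emb_dim), dtype=np.float32)
for c, idx in char2id.items():
    embedding_matrix[idx] = char2vec[c]

print(embedding_matrix.shape)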

(2) Encoding sentences with BERT
Here BERT is wrapped so that we can pass in a sentence directly and get back the corresponding vectors. For example:

from bert_encoder import BertEncoder

be = BertEncoder()
embedding = be.encode("新年快乐,恭喜发财,万事如意!")
print(embedding)
print(embedding.shape)

The complete wrapper is shown below (full code: https://github.com/lzphahaha/bert_encoder); a short pooling sketch follows the listing.

# -*- coding:utf-8 -*-
import os
import tensorflow as tf
from bert import modeling
from bert import tokenization

flags = tf.flags
FLAGS = flags.FLAGS

bert_path = r'chinese_L-12_H-768_A-12'
root_path = os.getcwd()

flags.DEFINE_string(
    "bert_config_file", os.path.join(bert_path, 'bert_config.json'),
    "The config json file corresponding to the pre-trained BERT model.")
flags.DEFINE_string(
    "vocab_file", os.path.join(bert_path, 'vocab.txt'),
    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
    "init_checkpoint", os.path.join(bert_path, 'bert_model.ckpt'),
    "Initial checkpoint of the pre-trained BERT model.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text.")
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization.")

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
    vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)


def data_preprocess(sentence):
    tokens = []
    for i, word in enumerate(sentence):
        # Tokenize; for Chinese this is effectively character-level
        token = tokenizer.tokenize(word)
        tokens.extend(token)
    # Truncate the sequence; -2 leaves room for the [CLS] and [SEP] markers
    if len(tokens) >= FLAGS.max_seq_length - 1:
        tokens = tokens[0:(FLAGS.max_seq_length - 2)]
    ntokens = []
    segment_ids = []
    ntokens.append("[CLS]")   # sentence start marker
    segment_ids.append(0)
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
    ntokens.append("[SEP]")   # sentence end marker
    segment_ids.append(0)
    # Convert the tokens (ntokens) to vocabulary ids
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    # Pad up to max_seq_length
    while len(input_ids) < FLAGS.max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
    # Add a batch dimension
    input_ids = [input_ids]
    input_mask = [input_mask]
    return input_ids, input_mask


class BertEncoder(object):
    def __init__(self):
        # Placeholders for the model inputs
        self.input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_ids")
        self.input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name="input_mask")
        self.bert_model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=self.input_ids,
            input_mask=self.input_mask)
        # Load the pretrained weights from the checkpoint
        tvars = tf.trainable_variables()
        (assignment_map, initialized_variable_names) = \
            modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_checkpoint)
        tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def encode(self, sentence):
        input_ids, input_mask = data_preprocess(sentence)
        # embedding_output is the output of BERT's embedding layer
        return self.sess.run(self.bert_model.embedding_output,
                             feed_dict={self.input_ids: input_ids,
                                        self.input_mask: input_mask})


if __name__ == "__main__":
    be = BertEncoder()
    embedding = be.encode("新年快乐,恭喜发财,万事如意!")
    print(embedding)
    print(embedding.shape)
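
A note on the output: encode() fetches embedding_output, which in the standard google-research BERT implementation is the output of the embedding layer (token + position + segment embeddings), not the contextual Transformer output, and it has shape (1, max_seq_length, hidden_size); fetching sequence_output would give contextual vectors instead. The minimal sketch below is my addition rather than part of the original repo: it reduces the per-token output to one fixed-size sentence vector by mean pooling over the real token positions, assuming one WordPiece token per input character, which holds for typical Chinese text.

import numpy as np

def sentence_vector(be, sentence):
    """Mean-pool the per-token embeddings into a single sentence vector."""
    embedding = be.encode(sentence)                 # shape: (1, max_seq_length, 768)
    # Non-padding positions: [CLS] + one token per character + [SEP]
    n_tokens = min(len(sentence) + 2, embedding.shape[1])
    return embedding[0, :n_tokens, :].mean(axis=0)  # shape: (768,)

# Usage:
# be = BertEncoder()
# vec = sentence_vector(be, "新年快乐,恭喜发财,万事如意!")
# print(vec.shape)   # (768,)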

References:
https://github.com/xmxoxo/BERT-Vector
https://github.com/lzphahaha/bert_encoder


Author: broccoli2


