Multi-class Text Classification with BERT

Scarlett · Updated: 2024-11-15

The code is available on GitHub: https://github.com/danan0755/Bert_Classifier

The data is the cnews dataset, which can be downloaded from Baidu Netdisk:

Link: https://pan.baidu.com/s/1LzTidW_LrdYMokN---Nyag
Extraction code: zejw

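Data format: each line of cnews.train.txt, cnews.val.txt and cnews.test.txt holds a category label and the news text, separated by a tab (this is how MyProcessor below parses it).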
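Before training it can help to sanity-check a downloaded file. The sketch below assumes the tab-separated label/text layout that MyProcessor parses (label in the first field, text in the second); parse_cnews_line and count_labels are hypothetical helpers for illustration, not part of the repository.

```python
LABELS = ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]


def parse_cnews_line(line):
    """Split one 'label<TAB>text' line into (label, text), like MyProcessor does."""
    parts = line.strip().split("\t")
    if len(parts) < 2:
        raise ValueError("expected 'label<TAB>text', got: %r" % line[:50])
    label, text = parts[0], parts[1]
    if label not in LABELS:
        raise ValueError("unknown label: %r" % label)
    return label, text


def count_labels(path):
    """Count examples per label, raising on the first malformed line."""
    counts = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            label, _ = parse_cnews_line(line)
            counts[label] = counts.get(label, 0) + 1
    return counts
```

For example, count_labels("cnews/cnews.val.txt") returns a dictionary mapping each category to its number of lines.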

Download link for the Chinese pre-trained BERT model (chinese_L-12_H-768_A-12):

Link: https://pan.baidu.com/s/14JcQXIBSaWyY7bRWdJW7yg
Extraction code: mvtl

Copy run_classifier.py from the BERT repository, rename the copy to run_cnews_cls.py, and add a custom Processor:

```python
import os
import random  # not imported by run_classifier.py, so add it at the top

# DataProcessor, InputExample and tokenization are already defined or
# imported in run_classifier.py, from which run_cnews_cls.py is copied.


class MyProcessor(DataProcessor):
    """Processor for the cnews dataset (one 'label<TAB>text' pair per line)."""

    def read_txt(self, file_path, flag):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Fixed seed, so every run samples the same subset
        random.seed(0)
        random.shuffle(lines)
        # Use only a small amount of data for training
        if flag == "train":
            lines = lines[0:5000]
        elif flag == "dev":
            lines = lines[0:500]
        elif flag == "test":
            lines = lines[0:100]
        return lines

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.train.txt"), "train"), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.val.txt"), "dev"), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.test.txt"), "test"), "test")

    def get_labels(self):
        """See base class."""
        # sports, entertainment, home, real estate, education,
        # fashion, politics, games, technology, finance
        return ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # Mirrors the header skipping in run_classifier.py's processors.
            # The cnews files have no header row, so this drops one shuffled example.
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            split_line = line.strip().split("\t")
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            if set_type == "test":
                # Placeholder label: prediction does not use the gold label
                label = "体育"
            else:
                label = tokenization.convert_to_unicode(split_line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
```
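Because read_txt shuffles with a fixed seed and then truncates, the sampled subsets can be reproduced outside TensorFlow, for instance to line predictions up with the original test lines. A minimal sketch, assuming the same Python version as the training run (sample_lines and usable_examples are hypothetical names):

```python
import random

# Subset sizes used by MyProcessor.read_txt
LIMITS = {"train": 5000, "dev": 500, "test": 100}


def sample_lines(lines, flag):
    """Reproduce read_txt's shuffle-and-truncate on a list of lines."""
    lines = list(lines)   # copy so the caller's list is left untouched
    random.seed(0)        # same fixed seed as read_txt
    random.shuffle(lines)
    return lines[:LIMITS[flag]]


def usable_examples(lines, flag):
    """Lines that actually become InputExamples (_create_examples skips index 0)."""
    return sample_lines(lines, flag)[1:]
```

So the 100 sampled test lines yield 99 examples, and the 500 dev lines yield 499.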

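Register the custom Processor in the processors dictionary in main(); the key cnews is the value later passed to --task_name:

```python
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "cnews": MyProcessor,  # selected with --task_name=cnews
    }
    # ... rest of main() unchanged
```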
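run_classifier.py lower-cases --task_name and looks it up in this dictionary, raising an error for unknown tasks. The sketch below reproduces that lookup with stand-in classes (the stubs and get_processor are hypothetical; in the real script the classes come from run_classifier.py):

```python
class DataProcessor(object):
    """Hypothetical stand-in for run_classifier.DataProcessor."""


class ColaProcessor(DataProcessor):
    pass


class MyProcessor(DataProcessor):
    pass


processors = {"cola": ColaProcessor, "cnews": MyProcessor}


def get_processor(task_name):
    """Mirror of the task lookup in main(): case-insensitive, fails loudly."""
    task_name = task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    return processors[task_name]()
```

Because of the lower-casing, --task_name=CNEWS would also select MyProcessor.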

Training command:

```shell
python run_cnews_cls.py \
  --task_name=cnews \
  --do_train=true \
  --do_eval=true \
  --do_predict=false \
  --data_dir=cnews \
  --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=pretrained_model/chinese_L-12_H-768_A-12/bert_model.ckpt \
  --max_seq_length=128 \
  --output_dir=model
```
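The command does not set --train_batch_size, --num_train_epochs or --warmup_proportion, so run_classifier.py's defaults (32, 3.0 and 0.1) apply. Its step formula then explains the global_step = 468 in the results; a worked check:

```python
def train_steps(num_examples, train_batch_size=32, num_train_epochs=3.0,
                warmup_proportion=0.1):
    """Same arithmetic as run_classifier.py for the step counts."""
    num_train_steps = int(num_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


# 5000 sampled training lines minus the one skipped in _create_examples
print(train_steps(4999))  # (468, 46)
```

With 5000 examples the result would be the same, since int(468.75) is also 468.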

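Prediction command. Here --init_checkpoint points at model, the output directory of the training run, so the fine-tuned weights are loaded; the original pre-trained bert_model.ckpt does not contain the classification layer, which would otherwise stay randomly initialized:

```shell
python run_cnews_cls.py \
  --task_name=cnews \
  --do_train=false \
  --do_eval=false \
  --do_predict=true \
  --data_dir=cnews \
  --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=model \
  --max_seq_length=128 \
  --output_dir=result
```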
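run_classifier.py writes predictions to test_results.tsv in --output_dir (here result): one line per test example, holding the class probabilities separated by tabs, in the order returned by get_labels(). A minimal sketch for turning that file into label names (decode_predictions is a hypothetical helper):

```python
LABELS = ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]


def decode_predictions(path, labels=LABELS):
    """Map each row of class probabilities to the label with the highest score."""
    predicted = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            probs = [float(p) for p in line.split("\t")]
            best = max(range(len(probs)), key=lambda k: probs[k])
            predicted.append(labels[best])
    return predicted
```

Usage: decode_predictions("result/test_results.tsv") returns one label per test example.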

Results (evaluation on the dev subset at the end of the training run):
INFO:tensorflow:  eval_accuracy = 0.93386775
INFO:tensorflow:  eval_loss = 0.33081177
INFO:tensorflow:  global_step = 468
INFO:tensorflow:  loss = 0.3427003
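As a consistency check, assuming the dev subset has 499 examples (500 sampled lines minus the one skipped in _create_examples), the only whole number of correct predictions that reproduces eval_accuracy = 0.93386775 is 466:

```python
n_dev = 500 - 1  # dev examples after the index-0 skip
reported = 0.93386775
matches = [k for k in range(n_dev + 1) if abs(k / n_dev - reported) < 1e-6]
print(matches)  # [466]
```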


Author: 永胜永胜



