新网创想网站建设,新征程启航

为企业提供网站建设、域名注册、服务器等服务

DialoGPT是什么

本篇内容介绍了“DialoGPT是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

创新互联建站专注于淳安企业网站建设,响应式网站,商城开发。淳安网站建设公司,为淳安等地区提供建站服务。全流程按需网站制作,专业设计,全程项目跟踪,创新互联建站专业和态度为您提供的服务

引言

论文:DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation

DialoGPT 是基于 GPT-2 的对话生成预训练模型,使用 Reddit 对话数据训练。

下文假定已经配置好运行环境。

在 gpt2_training/eval_utils.py 中增加 inference_model_results 函数:

def _decode_token_batch(batch_data, tokenizer, input_flag):
    """Decode each row of a batch of token ids or logits into a string.

    Args:
        batch_data: tensor holding either token ids (``input_flag=True``) or
            per-position vocabulary scores (``input_flag=False``). For scores,
            the argmax over the last axis is taken as the token id.
            Presumably (batch, seq_len) for ids and (batch, seq_len, vocab)
            for logits -- TODO confirm against the data loader.
        tokenizer: GPT-2 tokenizer exposing ``decoder`` (id -> token string)
            and ``decode`` (list of ids -> text).
        input_flag: True when ``batch_data`` already contains token ids.

    Returns:
        A list with one decoded string per row. Each row is cut just before
        its first ``<|endoftext|>`` token.
    """
    rows = batch_data.detach().cpu().numpy()
    results = []
    for row in rows:
        word_ids = row if input_flag else np.argmax(row, axis=-1)
        kept_ids = []
        for token_id in word_ids.tolist():
            if tokenizer.decoder[token_id] == "<|endoftext|>":
                break
            kept_ids.append(token_id)
        # tokenizer.decode merges byte-level BPE pieces (the "Ġ" space marker,
        # sub-word continuations, multi-byte UTF-8 characters) back into text.
        # Joining the raw pieces with spaces would split words apart.
        results.append(tokenizer.decode(kept_ids))
    return results


def inference_model_results(model, tokenizer, inference_dataloader, args,
                            max_batches=1):
    """Run ``model.inference`` and print decoded (post, prediction) pairs.

    Uses the same positional signature as ``eval_model_generation``.
    NOTE: the model is left in eval mode; switch it back to train mode after
    calling this function if training continues.

    Args:
        model: model exposing ``eval()`` and
            ``inference(input_ids, position_ids, token_ids)`` returning logits.
        tokenizer: GPT-2 tokenizer (see ``_decode_token_batch``).
        inference_dataloader: iterable of batches that unpack to
            ``(input_ids, position_ids, token_ids, label_ids, src_len, _)``.
        args: namespace providing ``device`` and ``no_token_id``.
        max_batches: number of batches to process; ``None`` processes the
            whole loader. Defaults to 1, i.e. only the first batch is printed.

    Returns:
        None. Results are printed to stdout.
    """
    import itertools

    logger.info('run model inference, using eval mode, '
                'please change it back to train after calling this function')
    model.eval()
    with torch.no_grad():
        # islice stops after max_batches items without pulling another batch.
        for batch in itertools.islice(inference_dataloader, max_batches):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, position_ids, token_ids, _label_ids, _src_len, _ = batch
            if args.no_token_id:
                token_ids = None
            # NOTE(review): these logits come from a single forward pass over
            # the given input, so the argmax is the predicted next token at
            # each input position -- not a sampled / beam-searched reply.
            logits = model.inference(input_ids, position_ids, token_ids)

            posts = _decode_token_batch(input_ids, tokenizer, True)
            predictions = _decode_token_batch(logits, tokenizer, False)

            logger.info("model inference results")
            for post, prediction in zip(posts, predictions):
                print("post: ", post)
                print("inference: ", prediction)
    return None

在 modeling_gpt2.py 的 GPT2LMHeadModel(GPT2PreTrainedModel) 类中增加 inference 方法:

def inference(self, input_ids, position_ids=None, token_type_ids=None, past=None):
    """Return language-model logits for ``input_ids`` without computing a loss.

    Runs ``self.transformer`` and projects its hidden states through
    ``self.lm_head``. The second value returned by the transformer
    (``presents``) is discarded.

    Args:
        input_ids: token id tensor passed to ``self.transformer``.
        position_ids: optional position ids, forwarded unchanged.
        token_type_ids: optional token-type ids, forwarded unchanged.
        past: optional cached state, forwarded unchanged.

    Returns:
        The output of ``self.lm_head`` applied to the hidden states.
    """
    hidden_states, _presents = self.transformer(
        input_ids, position_ids, token_type_ids, past)
    return self.lm_head(hidden_states)

新建自定义脚本 inference_LSP.py,文件内容如下:

# Copyright (c) Microsoft Corporation.

# Licensed under the MIT license.

'''

* @Desc: train GPT2 from scratch/ fine tuning.

Modified based on Huggingface GPT-2 implementation

'''

import json

import os

import sys

import argparse

import logging

import time

import tqdm

import datetime

import torch

import numpy as np

from os.path import join

from torch.distributed import get_rank, get_world_size

from lsp_model import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, Adam

from gpt2_training.train_utils import load_model, boolean_string, set_lr, get_eval_list_same_length

from gpt2_training.eval_utils import eval_model_loss, inference_model_results

from data_loader import BucketingDataLoader, DynamicBatchingLoader, DistributedBucketingDataLoader

from gpt2_training.distributed import all_reduce_and_rescale_tensors, all_gather_list

# Restrict CUDA to device 0. This overwrites any CUDA_VISIBLE_DEVICES value
# already present in the environment.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Numeric constants; nothing else in this script reads them.
INF = 100_000_000
CACHE_EMPTY_STEP = 10_000
EVAL_STEP = 10_000

#########################################################################
# Prepare Parser
##########################################################################
# Declarations are kept in their original order (it fixes --help output and
# the attribute order of the parsed namespace).
parser = argparse.ArgumentParser()

# -- model / data locations --------------------------------------------
parser.add_argument(
    '--model_name_or_path',
    type=str,
    required=True,
    help='pretrained model name or path to local checkpoint',
)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--max_seq_length', type=int, default=128)
parser.add_argument('--init_checkpoint', type=str, required=True)
parser.add_argument('--inference_input_file', type=str, required=True)
parser.add_argument('--inference_batch_size', type=int, default=8)
parser.add_argument(
    '--num_optim_steps',
    type=int,
    default=1000000,
    help='new API specifies num update steps',
)

# -- switches (boolean_string parses the textual true/false value) -------
parser.add_argument('--fp16', type=boolean_string, default=True)
parser.add_argument('--normalize_data', type=boolean_string, default=True)
parser.add_argument('--loss_scale', type=float, default=0)
parser.add_argument('--no_token_id', type=boolean_string, default=True)
parser.add_argument('--log_dir', type=str, required=True)

# -- distributed ---------------------------------------------------------
parser.add_argument(
    '--local_rank',
    type=int,
    default=-1,
    help='for torch.distributed',
)
parser.add_argument('--config', help='JSON config file')

# Parse sys.argv.
args = parser.parse_args()

if args.config is not None:
    # Override argparse defaults with values from the JSON config file.
    # `with` closes the file handle, which the previous open() call leaked.
    with open(args.config) as config_file:
        opts = json.load(config_file)
    for k, v in opts.items():
        if isinstance(v, str):
            # PHILLY ENV special cases: expand the directory placeholders from
            # the environment (KeyError if the variable is not set).
            if 'PHILLY_JOB_DIRECTORY' in v:
                v = v.replace('PHILLY_JOB_DIRECTORY',
                              os.environ['PHILLY_JOB_DIRECTORY'])
            elif 'PHILLY_LOG_DIRECTORY' in v:
                v = v.replace('PHILLY_LOG_DIRECTORY',
                              os.environ['PHILLY_LOG_DIRECTORY'])
        setattr(args, k, v)

    # Options given explicitly on the command line take precedence over the
    # config JSON. Both the "--name value" and "--name=value" spellings of
    # the exact flag name count as explicit.
    argv = sys.argv[1:]
    overrides, _ = parser.parse_known_args(argv)
    for k, v in vars(overrides).items():
        flag = f'--{k}'
        if any(a == flag or a.startswith(flag + '=') for a in argv):
            setattr(args, k, v)
    setattr(args, 'local_rank', overrides.local_rank)

if args.local_rank == -1:
    # Single-process run: use CUDA when available, otherwise the CPU.
    logger.info('CUDA available? {}'.format(str(torch.cuda.is_available())))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
else:
    # distributed training
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    # Initializes the distributed backend which will take care of
    # synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
    n_gpu = torch.distributed.get_world_size()
    # Each process drives exactly one GPU.
    args.device, args.n_gpu = device, 1

logger.info("device: {} n_gpu: {}, distributed training: {}, "
            "16-bits training: {}".format(
                device, n_gpu, bool(args.local_rank != -1), args.fp16))

# Wall-clock tag for this run (not referenced elsewhere in this script).
timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
log_dir = args.log_dir

# Log every parsed argument as a left-aligned "name value" line.
logger.info('Input Argument Information')
for arg_name, arg_value in vars(args).items():
    logger.info('%-28s %s', arg_name, arg_value)

#########################################################################
# Prepare Data Set
##########################################################################
print("Prepare Data")
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)

config = GPT2Config.from_json_file(
    join(args.model_name_or_path, 'config.json'))

# Loader whose batches are passed to inference_model_results, which unpacks
# each one as (input_ids, position_ids, token_ids, label_ids, src_len, _).
# No generation-style loader (get_eval_list_same_length) is built: nothing in
# this script consumes it, and building it re-reads the whole input file.
inference_dataloader_loss = DynamicBatchingLoader(
    args.inference_input_file, enc, args.normalize_data,
    args.inference_batch_size, args.max_seq_length)

#########################################################################
# Prepare Model
##########################################################################
print("Prepare Model")
logger.info("Prepare Model")
model = load_model(GPT2LMHeadModel(config), args.init_checkpoint,
                   args, verbose=True)
if args.local_rank != -1:
    # when from scratch make sure initial models are the same:
    # all-reduce the parameters across processes, rescaled by the world size
    # (presumably an average -- confirm in gpt2_training.distributed).
    params = [p.data for p in model.parameters()]
    all_reduce_and_rescale_tensors(
        params, float(torch.distributed.get_world_size()))

#########################################################################
# Inference !
##########################################################################
print("Model inference")
logger.info("Model inference")

# open() below fails if the log directory does not exist yet. An empty
# log_dir means the current directory, which needs no creation.
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

# `with` guarantees the log file is closed even if inference raises.
with open(join(log_dir, 'inference_log.txt'), 'a+',
          buffering=1) as inference_logger:
    # TODO: write formatted results to inference_logger.
    results = inference_model_results(
        model, enc, inference_dataloader_loss, args)

    logger.info("inference_final_results:")
    if results is None:
        logger.info("current results are None")
    else:
        logger.info(results)

python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/attack_chatbot.tsv --log_dir inference_logs_dir/

Inference

python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/attack_chatbot.tsv --log_dir inference_logs_dir/

对 validset.tsv 进行推理:

python inference_LSP.py --model_name_or_path ./models/medium/ --init_checkpoint ./12_5_self_output/GPT2.1e-05.8.3gpu.2019-12-04225327/GP2-pretrain-step-50000.pkl --inference_input_file ./selfdata/validset.tsv --log_dir inference_logs_dir/

./models/medium/medium_ft.pkl

“DialoGPT是什么”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注创新互联网站,小编将为大家输出更多高质量的实用文章!


网页名称:DialoGPT是什么
文章路径:http://wjwzjz.com/article/ghjeeg.html
在线咨询
服务热线
服务热线:028-86922220
TOP