目录
定义（argument.py）
# coding: utf8
import argparse
"""
定义模型参数
"""
class Argument(object):
    """Model parameters parsed from the command line.

    Positional arguments: ``train_path`` and ``save_path``.
    Options:
        -v/--valid_path     validation set path (default: None)
        -e/--epochs         number of epochs, int (default: 100)
        -c/--char_emb_path  char embedding path (default: None)
        -g/--gpu_id         GPU id, int (default: 0)
    """

    # Class-level placeholders; __init__ always overwrites them per instance.
    train_path = None
    save_path = None
    valid_path = None
    epochs = None
    char_emb_path = None
    gpu_id = None

    def __init__(self, args=None):
        """Parse ``args`` and store every option as an instance attribute.

        Args:
            args: list of argument strings without the program name, e.g.
                ``sys.argv[1:]``. ``None`` (the default) lets argparse read
                ``sys.argv[1:]`` itself.

        Raises:
            SystemExit: raised by argparse on invalid arguments or ``-h``.
        """
        parser = self.__get_parser()
        parse_args = parser.parse_args(args)
        # Echo the raw Namespace so the parse can be checked at a glance.
        print(parse_args)
        self.train_path = parse_args.train_path
        self.save_path = parse_args.save_path
        self.valid_path = parse_args.valid_path
        self.epochs = parse_args.epochs
        self.char_emb_path = parse_args.char_emb_path
        self.gpu_id = parse_args.gpu_id

    def __get_parser(self):
        """Build the ArgumentParser that describes every supported option."""
        parser = argparse.ArgumentParser()
        parser.add_argument("train_path", help="training set path")
        parser.add_argument("save_path", help="save path")
        # Validation set (optional).
        parser.add_argument("-v", "--valid_path", default=None,
                            help="validation set path (default: %(default)s)")
        parser.add_argument("-e", "--epochs", default=100, type=int,
                            help="number of epochs (default: %(default)s)")
        parser.add_argument("-c", "--char_emb_path", default=None,
                            help="char embedding path (default: %(default)s)")
        parser.add_argument("-g", "--gpu_id", default=0, type=int,
                            help="GPU id (default: %(default)s)")
        return parser

    def show_args(self):
        """Print each parsed option on its own ``name:<TAB>value`` line."""
        print("train_path:\t%s" % self.train_path)
        print("save_path:\t%s" % self.save_path)
        print("valid_path:\t%s" % self.valid_path)
        # epochs and gpu_id are ints (argparse type=int), hence %d.
        print("epochs:\t%d" % self.epochs)
        print("char_emb_path:\t%s" % self.char_emb_path)
        print("gpu_id:\t%d" % self.gpu_id)
使用（train.py）
# coding: utf8
"""Example entry point: parse the training arguments and print them."""
import sys

import argument


def main(argv=None):
    """Parse ``argv`` into an ``argument.Argument`` and print its values.

    Args:
        argv: full command line including the program name at index 0.
            Defaults to ``sys.argv``; index 0 is skipped when parsing.
    """
    if argv is None:
        argv = sys.argv
    # Echo the raw command line before parsing.
    print(argv)
    args = argument.Argument(argv[1:])
    args.show_args()


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()
结果
python train.py train.in model -v validation.in -c char_emb -e 10 -g 2
['train.py', 'train.in', 'model', '-v', 'validation.in', '-c', 'char_emb', '-e', '10', '-g', '2']
Namespace(char_emb_path='char_emb', epochs=10, gpu_id=2, save_path='model', train_path='train.in', valid_path='validation.in')
train_path: train.in
save_path: model
valid_path: validation.in
epochs: 10
char_emb_path: char_emb
gpu_id: 2
python train.py train.in model -v validation.in -c char_emb -e 10
['train.py', 'train.in', 'model', '-v', 'validation.in', '-c', 'char_emb', '-e', '10']
Namespace(char_emb_path='char_emb', epochs=10, gpu_id=0, save_path='model', train_path='train.in', valid_path='validation.in')
train_path: train.in
save_path: model
valid_path: validation.in
epochs: 10
char_emb_path: char_emb
gpu_id: 0