平台训练传参
传参说明
小波平台读取项目中的MLProject
配置文件,自动解析获取参数parameters
信息。
在页面填写参数后,平台会基于自定义的command
自动拼装参数,命令行运行脚本启动服务,传参形式完全自定义。
例MLProject参数配置如下:
entry_points:
main:
parameters:
mock_k: {type: int, default: 2}
mock_b: {type: int, default: 1}
training_data: {type: str, default: "data/get_started_train.csv"}
param_type: {type: str, default: "click"}
command: "python train.py -k={mock_k} -b={mock_b}"
配置中的parameters
会被自动解析到页面中,可以方便快捷的完成不同实验训练的参数设定。
四种常用的传参方法
1.click(推荐)
代码:
import click
@click.command()
@click.option("--mock_k", "-k", type=int, default=2, help="函数的权重系数")
@click.option("--mock_b", "-b", type=int, default=1, help="函数的截距")
@click.option("--training_data", "-td", type=str, default="data/get_started_train.csv", help="训练数据")
@click.option("--param_type", "-pt", type=str, default="click", help="用click方式传参")
def parse_click(mock_k, mock_b, training_data, param_type):
return biz(mock_k, mock_b, training_data, param_type)
def train(mock_k, mock_b, training_data, param_type):
# training
if __name__ == '__main__':
train_click()
配置:
command: "python train.py -k {mock_k} -b {mock_b} -td {training_data} -pt {param_type}"
2.tf.app.run
tensorflow提供的一种方便的解析方式。如果本身使用tensorflow进行训练,可以直接使用此模式进行传参。
代码:
import click
def parse_tf_app_run():
import tensorflow as tf
tf.app.flags.DEFINE_integer('mock_k', 2, '函数的权重系数')
tf.app.flags.DEFINE_integer('mock_b', 1, '函数的截距')
tf.app.flags.DEFINE_string('mock_b', "data/get_started_train.csv", '训练数据')
tf.app.flags.DEFINE_string('param_type', "click", "用click方式传参")
FLAGS = tf.app.flags.FLAGS
return biz(FLAGS.mock_k, FLAGS.mock_b, FLAGS.training_data, FLAGS.param_type)
def train(mock_k, mock_b, training_data, param_type):
# training
if __name__ == '__main__':
train_tf_app_run()
配置:
command: "python train.py -k={mock_k} -b={mock_b} -td={training_data} -pt={param_type}"
3.argparse(不建议)
注意: 平台将使用gunicorn启动服务,
gunicorn
与argparse
传参会出现冲突,如下图所示,不建议使用。 需要将parser.parse_args()
改写为vars(ap.parse_args(args=[]))
解决冲突后,方可正常使用。
一般使用bool, int, str, float这些基本类型,更复杂的需求可以通过str传入,然后手动解析。bool类型的解析比较特殊,传入任何值都会被解析成True,传入空值时才为False
代码:
import argparse
def parse_argparse():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--mock_k', type=int, default=2)
parser.add_argument('--mock_b', type=int, default=1)
parser.add_argument('--training_data', type=str, default="data/get_started_train.csv")
parser.add_argument("--param_type", type=str, default="click")
# 以下写法可能会导致命令行传参冲突,不要使用
# args = parser.parse_args()
args = vars(ap.parse_args(args=[]))
return biz(args.mock_k, args.mock_b, args.training_data, args.param_type)
def train(mock_k, mock_b, training_data, param_type):
# training
if __name__ == '__main__':
train_argparse()
配置:
command: "python train.py -k={mock_k} -b={mock_b} -td={training_data} -pt={param_type}"
4.sys.argv(不建议)
支持最基础的sys.argv,较为简单,不支持一些复杂用法。不建议使用
代码:
def train_sys_argv():
import sys
mock_k, mock_b, training_data, param_type = 1, 2, "data/get_started_train.csv", "click"
if sys.argv[1]:
mock_k = sys.argv[1]
if sys.argv[2]:
mock_b = sys.argv[2]
if sys.argv[3]:
training_data = sys.argv[3]
if sys.argv[4]:
param_type = sys.argv[4]
return biz(mock_k, mock_b, training_data, param_type)
配置:
command: "python train.py {mock_k} {mock_b} {training_data} {param_type}"