Skip to main content
Version: 2.17.1

平台训练传参

传参说明

小波平台读取项目中的MLProject配置文件,自动解析获取参数parameters信息。

在页面填写参数后,平台会基于自定义的command自动拼装参数,命令行运行脚本启动服务,传参形式完全自定义。

例MLProject参数配置如下:

entry_points:
main:
parameters:
mock_k: {type: int, default: 2}
mock_b: {type: int, default: 1}
training_data: {type: str, default: "data/get_started_train.csv"}
param_type: {type: str, default: "click"}
command: "python train.py -k={mock_k} -b={mock_b}"

配置中的parameters会被自动解析到页面中,可以方便快捷的完成不同实验训练的参数设定。

平台解析参数

四种常用的传参方法

1.click(推荐)

代码:

import click

@click.command()
@click.option("--mock_k", "-k", type=int, default=2, help="函数的权重系数")
@click.option("--mock_b", "-b", type=int, default=1, help="函数的截距")
@click.option("--training_data", "-td", type=str, default="data/get_started_train.csv", help="训练数据")
@click.option("--param_type", "-pt", type=str, default="click", help="用click方式传参")
def parse_click(mock_k, mock_b, training_data, param_type):
return biz(mock_k, mock_b, training_data, param_type)

def train(mock_k, mock_b, training_data, param_type):
# training

if __name__ == '__main__':
train_click()

配置:

command: "python train.py -k {mock_k} -b {mock_b} -td {training_data} -pt {param_type}"

2.tf.app.run

tensorflow提供的一种方便的解析方式。如果本身使用tensorflow进行训练,可以直接使用此模式进行传参。

代码:

import click

def parse_tf_app_run():
import tensorflow as tf
tf.app.flags.DEFINE_integer('mock_k', 2, '函数的权重系数')
tf.app.flags.DEFINE_integer('mock_b', 1, '函数的截距')
tf.app.flags.DEFINE_string('mock_b', "data/get_started_train.csv", '训练数据')
tf.app.flags.DEFINE_string('param_type', "click", "用click方式传参")
FLAGS = tf.app.flags.FLAGS
return biz(FLAGS.mock_k, FLAGS.mock_b, FLAGS.training_data, FLAGS.param_type)

def train(mock_k, mock_b, training_data, param_type):
# training

if __name__ == '__main__':
train_tf_app_run()

配置:

command: "python train.py -k={mock_k} -b={mock_b} -td={training_data} -pt={param_type}"

3.argparse(不建议)

注意: 平台将使用gunicorn启动服务,gunicornargparse传参会出现冲突,如下图所示,不建议使用。 需要将parser.parse_args() 改写为 vars(ap.parse_args(args=[])) 解决冲突后,方可正常使用。

一般使用bool, int, str, float这些基本类型,更复杂的需求可以通过str传入,然后手动解析。bool类型的解析比较特殊,传入任何值都会被解析成True,传入空值时才为False

代码:

import argparse

def parse_argparse():
parser = argparse.ArgumentParser(description='')
parser.add_argument('--mock_k', type=int, default=2)
parser.add_argument('--mock_b', type=int, default=1)
parser.add_argument('--training_data', type=str, default="data/get_started_train.csv")
parser.add_argument("--param_type", type=str, default="click")
# 以下写法可能会导致命令行传参冲突,不要使用
# args = parser.parse_args()
args = vars(ap.parse_args(args=[]))
return biz(args.mock_k, args.mock_b, args.training_data, args.param_type)

def train(mock_k, mock_b, training_data, param_type):
# training

if __name__ == '__main__':
train_argparse()

配置:

command: "python train.py -k={mock_k} -b={mock_b} -td={training_data} -pt={param_type}"

4.sys.argv(不建议)

支持最基础的sys.argv,较为简单,不支持一些复杂用法。不建议使用

代码:

def train_sys_argv():
import sys
mock_k, mock_b, training_data, param_type = 1, 2, "data/get_started_train.csv", "click"

if sys.argv[1]:
mock_k = sys.argv[1]
if sys.argv[2]:
mock_b = sys.argv[2]
if sys.argv[3]:
training_data = sys.argv[3]
if sys.argv[4]:
param_type = sys.argv[4]
return biz(mock_k, mock_b, training_data, param_type)

配置:

command: "python train.py {mock_k} {mock_b} {training_data} {param_type}"