|
#!/bin/bash
# Launch script for the DeepFinance training example.
#
# Reads from the environment:
#   AJET_ROOT  - repository root; must contain .venv/ and (optionally) .env
#   WORLD_SIZE - copied into NNODES
set -e
#===============================================================================
# 1. Configuration area - users only need to edit this section
#===============================================================================
SUFFIX="ajet_deep_finance"  # experiment suffix; used in log and experiment names
PREFIX="open"               # experiment prefix; selects the log/experiment folder

# OpenJudge model configuration
OPENJUDGE_LLM='qwen-flash'  # OpenJudge scoring model
RM_LLM='qwen-max'           # RM Gallery scoring model
JUDGE_CONCURRENCY=10

# Reward weight configuration
RM_WEIGHT=0.4
CITATION_AUDIT_WEIGHT=0.2
REPORT_RESOLUTION_WEIGHT=0.2
TRAJECTORY_FAITHFULNESS_WEIGHT=0.2

# Training parameter configuration
NUM_REPEAT=4          # group size: each query is rolled out NUM_REPEAT times
TRAIN_BATCH_SIZE=32   # training batch size
NUM_STEPS=6           # number of steps per sample
DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000

# Node count, taken from WORLD_SIZE
NNODES="${WORLD_SIZE}"

# Abort early when AJET_ROOT is missing: a bare `cd` with an empty argument
# would silently switch to $HOME and activate the wrong virtualenv.
: "${AJET_ROOT:?AJET_ROOT must be set to the repository root}"

# Sensitive settings (API keys, model and data locations) are read from .env
cd "${AJET_ROOT}"
source .venv/bin/activate

# API key configuration - loaded from the .env file
ENV_FILE="${AJET_ROOT}/.env"
if [[ -f "$ENV_FILE" ]]; then
  # set -a exports every variable assigned while the file is sourced
  set -a
  source "$ENV_FILE"
  set +a
  echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m"
else
  # Best effort: keep going without .env, but report it on stderr
  echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" >&2
fi
| 44 | + |
#===============================================================================
# 2. Generate the config file dynamically (yaml template -> yaml)
#===============================================================================
# Without AJET_ROOT the output path would resolve under the filesystem root.
: "${AJET_ROOT:?AJET_ROOT must be set to the repository root}"

# The rendered config is written to the example's yaml/ directory.
CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml"
CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml"
mkdir -p "$(dirname "${CONFIG_FILE}")"

if [[ ! -f "${AJET_ROOT}/${CONFIG_TEMPLATE}" ]]; then
  echo "error: config template not found: ${AJET_ROOT}/${CONFIG_TEMPLATE}" >&2
  exit 1
fi

# Each {{NAME}} placeholder in the template is replaced by the value of the
# shell variable NAME.
template_vars=(
  SUFFIX PREFIX MODEL_PATH NNODES
  RM_WEIGHT CITATION_AUDIT_WEIGHT
  OPENJUDGE_LLM RM_LLM JUDGE_CONCURRENCY
  REPORT_RESOLUTION_WEIGHT TRAJECTORY_FAITHFULNESS_WEIGHT
  NUM_REPEAT NUM_STEPS TRAIN_BATCH_SIZE
  TRAIN_DATA_PATH VAL_DATA_PATH ENV_SERVICE_URL
  TRAIN_REF_ANS_PATH VAL_REF_ANS_PATH CKPT_SAVE_PATH
)

sed_args=()
for var_name in "${template_vars[@]}"; do
  # Escape backslash, '&' and the '|' delimiter so that values such as paths
  # and URLs are inserted literally instead of being parsed by sed.
  escaped_value=$(printf '%s' "${!var_name-}" | sed -e 's/[\\&|]/\\&/g')
  sed_args+=(-e "s|{{${var_name}}}|${escaped_value}|g")
done

sed "${sed_args[@]}" "${AJET_ROOT}/${CONFIG_TEMPLATE}" > "${CONFIG_FILE}"
unset template_vars sed_args var_name escaped_value

echo "配置文件已生成: ${CONFIG_FILE}"
echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}"
| 77 | + |
#===============================================================================
# 3. Environment configuration
#===============================================================================
# Without AJET_ROOT the generated files would be written under the filesystem root.
: "${AJET_ROOT:?AJET_ROOT must be set to the repository root}"

# ADDR and MCP_PORT are interpolated into the service URLs below; warn (without
# aborting) when either is missing so that a broken URL is not produced silently.
for required_var in ADDR MCP_PORT; do
  if [[ -z "${!required_var:-}" ]]; then
    echo "warning: ${required_var} is not set; generated service URLs will be invalid" >&2
  fi
done
unset required_var

# MongoDB cache configuration
CACHE_TYPE="mongodb"
MONGO_URI="mongodb://${ADDR}:27117/"
MONGO_DB_NAME="finworld_cache"
MONGO_COLLECTION_NAME="tool_cache"
export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME

# DeepFinance MCP configuration
DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json"

# Generate the MCP config file dynamically (ADDR and MCP_PORT are expanded
# inside the here-document)
mkdir -p "$(dirname "${DEEPFINANCE_MCP_CONFIG}")"
cat > "${DEEPFINANCE_MCP_CONFIG}" << EOF
{
  "mcpServers": {
    "flowllm": {
      "transport": "sse",
      "url": "http://${ADDR}:${MCP_PORT}/sse",
      "timeout": 600,
      "sse_read_timeout": 1200
    }
  }
}
EOF
export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS

# Other service configuration
HF_ENDPOINT="https://hf-mirror.com"
ES_HOSTS="http://11.160.132.46:8200"
export HF_ENDPOINT ES_HOSTS

# Log file locations
CURRENT_TIME=$(date "+%Y%m%d_%H%M%S")
LOG_DIR="${AJET_ROOT}/logs/${PREFIX}"
MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log"
ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log"
TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log"

# Multi-node training parameters
GPUS_PER_NODE=8
EXPECTED_WORKERS="${WORLD_SIZE}"
| 122 | + |
| 123 | + |
#===============================================================================
# 4. Utility functions and NCCL configuration (fixed)
#===============================================================================
#######################################
# Print a message in green on stdout, followed by a newline.
# Arguments: $1 - message text; printed literally, so backslash sequences
#                 inside the message are not interpreted.
# Outputs:   the message wrapped in ANSI green / reset escape codes
#######################################
print_green() {
  printf '\033[32m%s\033[0m\n' "${1-}"
}
| 130 | + |
#######################################
# Print a timestamped INFO line on stdout.
# Arguments: $1 - message text; printed literally, so backslash sequences
#                 inside the message are not interpreted.
# Outputs:   "[YYYY-MM-DD HH:MM:SS] [INFO] <message>" with a green timestamp
#            and a blue level tag (ANSI escape codes)
#######################################
log() {
  printf '\033[0;32m[%s]\033[0m \033[0;34m[INFO]\033[0m %s\n' \
    "$(date '+%Y-%m-%d %H:%M:%S')" "${1-}"
}
| 134 | + |
#######################################
# Count the nodes reported by `ray status`.
# Outputs: the node count on stdout; 0 when `ray status` prints nothing
#          (its stderr is discarded).
#######################################
check_workers() {
  local status_output node_count

  # `|| true`: a failing `ray status` must yield 0 rather than abort the
  # script under `set -e`.
  status_output=$(ray status 2>/dev/null) || true
  if [[ -z "$status_output" ]]; then
    echo 0
    return
  fi

  # Preferred format: one "1 node_<id>" line per node.
  # grep -c still prints 0 when nothing matches, but exits 1; hence `|| true`.
  node_count=$(grep -cE '^[[:space:]]*1[[:space:]]+node_' <<<"$status_output") || true
  if (( node_count > 0 )); then
    echo "$node_count"
    return
  fi

  # Fallback: count the distinct node_<hex> identifiers anywhere in the output.
  node_count=$(grep -oE 'node_[0-9a-f]+' <<<"$status_output" | sort -u | wc -l)
  # Arithmetic expansion strips the padding some `wc` implementations emit.
  echo "$(( node_count ))"
}
| 142 | + |
#######################################
# Report the total GPU count from the "Resources" section of `ray status`.
# Looks at the 10 lines following a "Resources" match and uses the first line
# containing "GPU" whose first field has the form <used>/<total> with a
# numeric total.
# Outputs: the total rounded to an integer (no trailing newline, as before),
#          or "0" when no such line is found.
#######################################
check_gpu_resources() {
  local gpu_total

  # `|| true`: a failing `ray status` must yield 0 rather than abort the
  # script under `set -e`.
  gpu_total=$(
    ray status 2>/dev/null \
      | grep -A 10 "Resources" \
      | awk '/GPU/ && split($1, parts, "/") >= 2 && parts[2] ~ /^[0-9]+(\.[0-9]+)?$/ {
          print parts[2]
          exit
        }'
  ) || true

  if [[ -z "$gpu_total" ]]; then
    echo 0
  else
    printf '%.0f' "$gpu_total"
  fi
}
| 147 | + |
| 148 | + |
# NCCL settings (fixed values)
export NCCL_TIMEOUT=1800
export NCCL_DEBUG=WARN
export NCCL_IB_TIMEOUT=23
export NCCL_ASYNC_ERROR_HANDLING=1

#===============================================================================
# 5. Environment variables for the tool env service
#===============================================================================

# Append the previous PYTHONPATH only when it is non-empty. An unconditional
# "${AJET_ROOT}:${PYTHONPATH}" leaves a trailing ':' (an empty search-path
# entry, typically treated as the current directory) when PYTHONPATH was unset.
export PYTHONPATH="${AJET_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"
export RAY_CLUSTER_MODE="multi_node"
| 160 | + |
| 161 | + |
#===============================================================================
# 6. Main flow
#===============================================================================
log "开始多机多卡训练: ${SUFFIX}"
log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}"
mkdir -p "${LOG_DIR}"
mkdir -p "$(dirname "${CONFIG_FILE}")"

#===============================================================================
# 6.1 Master node startup
#===============================================================================
# Launch the training job (the core step). Output is mirrored to TRAIN_LOG.
# pipefail makes the pipeline report the launcher's exit status; without it
# the status of `tee` is used and a failed run would exit 0.
set -o pipefail
python ajet/launcher.py \
  --conf "${CONFIG_FILE}" \
  --backbone="debug" \
  2>&1 | tee "${TRAIN_LOG}"
0 commit comments