跳至内容

扩展 Sb3 和 Ray 的 launch.py

Schola 支持通过其他回调和日志记录来扩展 ray 和 sb3 的 launch.py 脚本。这是通过 Python 插件 完成的,这些插件可以自动发现扩展 schola 的新代码。插件必须继承适当的基类,并注册一个适当命名的入口点。然后 launch.py 脚本将自动 发现插件 并使用它来修改训练过程。以下是具体扩展 schola.scripts.sb3.launchschola.scripts.ray.launch 脚本的步骤。

扩展 schola.scripts.sb3.launch

您可以通过其他回调、用于日志记录的 KVWriters 和命令行参数来扩展 schola.scripts.sb3.launch。以下是一个实现插件的示例,该插件添加了一个 CSV 日志记录器和一个每 N 个时间步记录一次的回调。

  1. 创建一个新类,继承自 Sb3LauncherExtension,并在相关时实现以下方法: get_extra_KVWritersget_extra_callbacksadd_plugin_args_to_parser
from schola.scripts.common import Sb3LauncherExtension
from dataclasses import dataclass
from typing import Dict, Any
import argparse
from stable_baselines3.common.logger import KVWriters, CSVOutputFormat
from stable_baselines3.common.callbacks import LogEveryNTimesteps
@dataclass
class ExampleSb3Extension(Sb3LauncherExtension):
csv_save_path: str = "./output.csv"
log_frequency: int = 1000
def get_extra_KVWriters(self):
return [CSVOutputFormat(self.csv_save_path)]
def get_extra_callbacks(self):
return [LogEveryNTimesteps(n_steps=log_frequency)]
@classmethod
def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser):
"""
Add example logging arguments to the parser.
Parameters
----------
parser : argparse.ArgumentParser
The parser to which the arguments will be added.
"""
group = parser.add_argument_group("CSV Logging")
group.add_argument("--csv-save-path", type=str, help="The path to save the CSV file to")
group.add_argument("--log-frequency", type=int, help="The frequency to log to the terminal")
  1. 创建一个新的 Python 包,并在 schola.plugins.sb3.launch 组中设置指向新类的入口点。
setup(
...,
entry_points={
'schola.plugins.sb3.launch': [
'example_extension_name = example_plugin_name.example_extension_name:ExampleSb3Extension',
],
},
...,
)

扩展 schola.scripts.ray.launch

您可以通过其他回调和命令行参数来扩展 schola.scripts.ray.launch。以下是一个实现插件的示例,该插件添加了对 Wandb 日志记录的支持。

  1. 创建一个新类,继承自 RLLibLauncherExtension,并在相关时实现以下方法: get_extra_callbacksadd_plugin_args_to_parser
from schola.scripts.common import RLLibLauncherExtension
from dataclasses import dataclass
from typing import Any, Dict, List
import argparse
from ray.tune.integration.wandb import WandbLoggerCallback
@dataclass
class ExampleRayExtension(RLLibLauncherExtension):
experiment_id: str = None
def get_extra_callbacks(self):
return [WandbLoggerCallback(project=self.experiment_id)]
@classmethod
def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser):
"""
Add example logging arguments to the parser.
Parameters
----------
parser : argparse.ArgumentParser
The parser to which the arguments will be added.
"""
group = parser.add_argument_group("Wandb Logging")
group.add_argument("--experiment-id", type=str, help="The experiment ID to log to")
  1. 创建一个新的 Python 包,并在 schola.plugins.ray.launch 组中设置指向新类的入口点。
setup(
...,
entry_points={
'schola.plugins.ray.launch': [
'example_extension_name = example_plugin_name.example_extension_name:ExampleRayExtension',
],
},
...,
)
© . This site is unofficial and not affiliated with AMD.