扩展 Sb3 和 Ray 的 launch.py
Schola 支持通过其他回调和日志记录来扩展 ray 和 sb3 的 launch.py 脚本。这是通过 Python 插件 完成的,这些插件可以自动发现扩展 schola 的新代码。插件必须继承适当的基类,并注册一个适当命名的入口点。然后 launch.py 脚本将自动 发现插件 并使用它来修改训练过程。以下是具体扩展 schola.scripts.sb3.launch 和 schola.scripts.ray.launch 脚本的步骤。
扩展 schola.scripts.sb3.launch
您可以通过其他回调、用于日志记录的 KVWriters 和命令行参数来扩展 schola.scripts.sb3.launch。以下是一个实现插件的示例,该插件添加了一个 CSV 日志记录器和一个每 N 个时间步记录一次的回调。
- 创建一个新类,继承自
Sb3LauncherExtension,并在相关时实现以下方法:get_extra_KVWriters、get_extra_callbacks和add_plugin_args_to_parser。
from schola.scripts.common import Sb3LauncherExtensionfrom dataclasses import dataclassfrom typing import Dict, Anyimport argparsefrom stable_baselines3.common.logger import KVWriters, CSVOutputFormatfrom stable_baselines3.common.callbacks import LogEveryNTimesteps
@dataclassclass ExampleSb3Extension(Sb3LauncherExtension): csv_save_path: str = "./output.csv" log_frequency: int = 1000
def get_extra_KVWriters(self): return [CSVOutputFormat(self.csv_save_path)]
def get_extra_callbacks(self): return [LogEveryNTimesteps(n_steps=log_frequency)]
@classmethod def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser): """ Add example logging arguments to the parser.
Parameters ---------- parser : argparse.ArgumentParser The parser to which the arguments will be added. """ group = parser.add_argument_group("CSV Logging") group.add_argument("--csv-save-path", type=str, help="The path to save the CSV file to") group.add_argument("--log-frequency", type=int, help="The frequency to log to the terminal")- 创建一个新的 Python 包,并在
schola.plugins.sb3.launch组中设置指向新类的入口点。
setup( ..., entry_points={ 'schola.plugins.sb3.launch': [ 'example_extension_name = example_plugin_name.example_extension_name:ExampleSb3Extension', ], }, ...,)扩展 schola.scripts.ray.launch
您可以通过其他回调和命令行参数来扩展 schola.scripts.ray.launch。以下是一个实现插件的示例,该插件添加了对 Wandb 日志记录的支持。
- 创建一个新类,继承自
RLLibLauncherExtension,并在相关时实现以下方法:get_extra_callbacks和add_plugin_args_to_parser。
from schola.scripts.common import RLLibLauncherExtensionfrom dataclasses import dataclassfrom typing import Any, Dict, Listimport argparsefrom ray.tune.integration.wandb import WandbLoggerCallback
@dataclassclass ExampleRayExtension(RLLibLauncherExtension): experiment_id: str = None
def get_extra_callbacks(self): return [WandbLoggerCallback(project=self.experiment_id)]
@classmethod def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser): """ Add example logging arguments to the parser.
Parameters ---------- parser : argparse.ArgumentParser The parser to which the arguments will be added. """ group = parser.add_argument_group("Wandb Logging") group.add_argument("--experiment-id", type=str, help="The experiment ID to log to")- 创建一个新的 Python 包,并在
schola.plugins.ray.launch组中设置指向新类的入口点。
setup( ..., entry_points={ 'schola.plugins.ray.launch': [ 'example_extension_name = example_plugin_name.example_extension_name:ExampleRayExtension', ], }, ...,)