schola.sb3.utils.SB3A2CModel
类定义
class schola.sb3.utils.SB3A2CModel(policy, action_space)基类: SB3PPOModel
参数
策略
类型: Policy
策略模型。
action_space
类型: Space
动作空间。
属性
T_destination
call_super_init
dump_patches
training
方法
__init__
__init__(policy, action_space)初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。
add_module
add_module(name, module)将子模块添加到当前模块。
apply
apply(fn)将 fn 递归地应用于每个子模块(由 .children() 返回)以及自身。
bfloat16
bfloat16()将所有浮点参数和缓冲区转换为 bfloat16 数据类型。
buffers
buffers(recurse=True)返回模块缓冲区迭代器。
children
children()返回直接子模块的迭代器。
compile
compile(*args, **kwargs)使用 torch.compile() 编译此模块的前向传播。
cpu
cpu()将所有模型参数和缓冲区移动到 CPU。
cuda
cuda(device=None)将所有模型参数和缓冲区移动到 GPU。
double
double()将所有浮点参数和缓冲区转换为 double 数据类型。
eval
eval()将模块设置为评估模式。
extra_repr
extra_repr()返回模块的额外表示。
float
float()将所有浮点参数和缓冲区转换为 float 数据类型。
forward
forward(*args)定义每次调用时执行的计算。
get_buffer
get_buffer(target)如果存在,则返回由 target 指定的缓冲区,否则抛出错误。
get_extra_state
get_extra_state()返回模块 state_dict 中应包含的任何额外状态。
get_logits
get_logits(x)get_parameter
get_parameter(target)如果存在,则返回由 target 指定的参数,否则抛出错误。
get_submodule
get_submodule(target)如果存在,则返回由 target 指定的子模块,否则抛出错误。
half
half()将所有浮点参数和缓冲区转换为 half 数据类型。
ipu
ipu(device=None)将所有模型参数和缓冲区移动到 IPU。
load_state_dict
load_state_dict(state_dict, strict=True, assign=False)将参数和缓冲区从 state_dict 复制到此模块及其子模块。
modules
modules()返回网络中所有模块的迭代器。
mtia
mtia(device=None)将所有模型参数和缓冲区移动到 MTIA。
named_buffers
named_buffers(prefix='', recurse=True, remove_duplicate=True)返回模块缓冲区迭代器,同时生成缓冲区名称和缓冲区本身。
named_children
named_children()返回直接子模块的迭代器,同时生成模块名称和模块本身。
named_modules
named_modules(memo=None, prefix='', remove_duplicate=True)返回网络中所有模块的迭代器,同时生成模块名称和模块本身。
named_parameters
named_parameters(prefix='', recurse=True, remove_duplicate=True)返回模块参数迭代器,同时生成参数名称和参数本身。
parameters
parameters(recurse=True)返回模块参数的迭代器。
register_backward_hook
register_backward_hook(hook)在模块上注册一个后向钩子。
register_buffer
register_buffer(name, tensor, persistent=True)向模块添加一个缓冲区。
register_forward_hook
register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)在模块上注册一个前向钩子。
register_forward_pre_hook
register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)在模块上注册一个前向预钩子。
register_full_backward_hook
register_full_backward_hook(hook, prepend=False)在模块上注册一个后向钩子。
register_full_backward_pre_hook
register_full_backward_pre_hook(hook, prepend=False)在模块上注册一个后向预钩子。
register_load_state_dict_post_hook
register_load_state_dict_post_hook(hook)注册一个在模块调用 load_state_dict() 后运行的后钩子。
register_load_state_dict_pre_hook
register_load_state_dict_pre_hook(hook)注册一个在模块调用 load_state_dict() 前运行的预钩子。
register_module
register_module(name, module)add_module() 的别名。
register_parameter
register_parameter(name, param)向模块添加一个参数。
register_state_dict_post_hook
register_state_dict_post_hook(hook)state_dict() 方法的后钩子注册。
register_state_dict_pre_hook
register_state_dict_pre_hook(hook)state_dict() 方法的预钩子注册。
requires_grad_
requires_grad_(requires_grad=True)更改 autograd 是否应记录此模块中参数的操作。
save_as_onnx
save_as_onnx(export_path, onnx_opset=17)set_extra_state
set_extra_state(state)设置加载的 state_dict 中包含的额外状态。
set_submodule
set_submodule(target, module)如果存在,则设置由 target 指定的子模块,否则抛出错误。
share_memory
share_memory()参见 torch.Tensor.share_memory_()。
state_dict
state_dict(*args, destination=None, prefix='', keep_vars=False)返回一个包含模块整个状态引用的字典。
to
to(*args, **kwargs)移动和/或转换参数和缓冲区。
to_empty
to_empty(*, device, recurse=True)将参数和缓冲区移动到指定设备,而不复制存储。
train
train(mode=True)将模块设置为训练模式。
type
type(dst_type)将所有参数和缓冲区转换为 dst_type。
xpu
xpu(device=None)将所有模型参数和缓冲区移动到 XPU。
zero_grad
zero_grad(set_to_none=True)重置所有模型参数的梯度。