跳至内容

schola.sb3.utils.SB3A2CModel

类定义

class schola.sb3.utils.SB3A2CModel(policy, action_space)

基类: SB3PPOModel

参数

策略

类型: Policy
策略模型。

action_space

类型: Space
动作空间。

属性

T_destination

call_super_init

dump_patches

training

方法

__init__

__init__(policy, action_space)

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

add_module

add_module(name, module)

将子模块添加到当前模块。

apply

apply(fn)

fn 递归地应用于每个子模块(由 .children() 返回)以及自身。

bfloat16

bfloat16()

将所有浮点参数和缓冲区转换为 bfloat16 数据类型。

buffers

buffers(recurse=True)

返回模块缓冲区迭代器。

children

children()

返回直接子模块的迭代器。

compile

compile(*args, **kwargs)

使用 torch.compile() 编译此模块的前向传播。

cpu

cpu()

将所有模型参数和缓冲区移动到 CPU。

cuda

cuda(device=None)

将所有模型参数和缓冲区移动到 GPU。

double

double()

将所有浮点参数和缓冲区转换为 double 数据类型。

eval

eval()

将模块设置为评估模式。

extra_repr

extra_repr()

返回模块的额外表示。

float

float()

将所有浮点参数和缓冲区转换为 float 数据类型。

forward

forward(*args)

定义每次调用时执行的计算。

get_buffer

get_buffer(target)

如果存在,则返回由 target 指定的缓冲区,否则抛出错误。

get_extra_state

get_extra_state()

返回模块 state_dict 中应包含的任何额外状态。

get_logits

get_logits(x)

get_parameter

get_parameter(target)

如果存在,则返回由 target 指定的参数,否则抛出错误。

get_submodule

get_submodule(target)

如果存在,则返回由 target 指定的子模块,否则抛出错误。

half

half()

将所有浮点参数和缓冲区转换为 half 数据类型。

ipu

ipu(device=None)

将所有模型参数和缓冲区移动到 IPU。

load_state_dict

load_state_dict(state_dict, strict=True, assign=False)

将参数和缓冲区从 state_dict 复制到此模块及其子模块。

modules

modules()

返回网络中所有模块的迭代器。

mtia

mtia(device=None)

将所有模型参数和缓冲区移动到 MTIA。

named_buffers

named_buffers(prefix='', recurse=True, remove_duplicate=True)

返回模块缓冲区迭代器,同时生成缓冲区名称和缓冲区本身。

named_children

named_children()

返回直接子模块的迭代器,同时生成模块名称和模块本身。

named_modules

named_modules(memo=None, prefix='', remove_duplicate=True)

返回网络中所有模块的迭代器,同时生成模块名称和模块本身。

named_parameters

named_parameters(prefix='', recurse=True, remove_duplicate=True)

返回模块参数迭代器,同时生成参数名称和参数本身。

parameters

parameters(recurse=True)

返回模块参数的迭代器。

register_backward_hook

register_backward_hook(hook)

在模块上注册一个后向钩子。

register_buffer

register_buffer(name, tensor, persistent=True)

向模块添加一个缓冲区。

register_forward_hook

register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)

在模块上注册一个前向钩子。

register_forward_pre_hook

register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)

在模块上注册一个前向预钩子。

register_full_backward_hook

register_full_backward_hook(hook, prepend=False)

在模块上注册一个后向钩子。

register_full_backward_pre_hook

register_full_backward_pre_hook(hook, prepend=False)

在模块上注册一个后向预钩子。

register_load_state_dict_post_hook

register_load_state_dict_post_hook(hook)

注册一个在模块调用 load_state_dict() 后运行的后钩子。

register_load_state_dict_pre_hook

register_load_state_dict_pre_hook(hook)

注册一个在模块调用 load_state_dict() 前运行的预钩子。

register_module

register_module(name, module)

add_module() 的别名。

register_parameter

register_parameter(name, param)

向模块添加一个参数。

register_state_dict_post_hook

register_state_dict_post_hook(hook)

state_dict() 方法的后钩子注册。

register_state_dict_pre_hook

register_state_dict_pre_hook(hook)

state_dict() 方法的预钩子注册。

requires_grad_

requires_grad_(requires_grad=True)

更改 autograd 是否应记录此模块中参数的操作。

save_as_onnx

save_as_onnx(export_path, onnx_opset=17)

set_extra_state

set_extra_state(state)

设置加载的 state_dict 中包含的额外状态。

set_submodule

set_submodule(target, module)

如果存在,则设置由 target 指定的子模块,否则抛出错误。

share_memory

share_memory()

参见 torch.Tensor.share_memory_()

state_dict

state_dict(*args, destination=None, prefix='', keep_vars=False)

返回一个包含模块整个状态引用的字典。

to

to(*args, **kwargs)

移动和/或转换参数和缓冲区。

to_empty

to_empty(*, device, recurse=True)

将参数和缓冲区移动到指定设备,而不复制存储。

train

train(mode=True)

将模块设置为训练模式。

type

type(dst_type)

将所有参数和缓冲区转换为 dst_type

xpu

xpu(device=None)

将所有模型参数和缓冲区移动到 XPU。

zero_grad

zero_grad(set_to_none=True)

重置所有模型参数的梯度。

© . This site is unofficial and not affiliated with AMD.