注意
跳转到末尾下载完整示例代码。
如何在 OptunaHub 中实现您的剪枝器
OptunaHub 支持 Optuna 的剪枝器、采样器和可视化函数。本教程展示了如何在 OptunaHub 中实现并注册您自己的剪枝器。
通常,Optuna 提供 BasePruner
类来帮助您实现自己的采样器。您可以通过继承此类来实现自己的剪枝器。您需要安装 optuna
来实现自己的剪枝器。
$ pip install optuna
首先,导入 optuna 和其他必需的模块。
from __future__ import annotations
import optuna
from optuna.pruners import BasePruner
接下来,通过继承 BasePruner
类来定义您自己的剪枝器类。在本例中,我们实现了一个基于给定阈值停止目标函数的剪枝器。
class MyPruner(BasePruner): # type: ignore
def __init__(self, upper_threshold: float, n_warmup_steps: int) -> None:
self._upper_threshold = upper_threshold
self._n_warmup_steps = n_warmup_steps
# You need to implement `prune` method.
# This method returns true if it stops objective function, otherwise false.
# It stops the objective function if the intermediate value exceeds the threshold.
# Note that first `n_warmup_steps` steps are not pruned.
def prune(
self,
study: optuna.study.Study,
trial: optuna.trial.FrozenTrial,
) -> bool:
step = trial.last_step
if step is None:
return False
if step < self._n_warmup_steps:
return False
if trial.intermediate_values[step] > self._upper_threshold:
return True
return False
在此示例中,目标函数是一个简单的二次函数。它有 20 个变量,返回这些变量的平方和。
def objective(trial: optuna.trial.Trial) -> float:
s = 0.0
for step in range(20):
x = trial.suggest_float(f"x_{step}", -5, 5)
s += x**2
trial.report(s, step)
if trial.should_prune():
raise optuna.TrialPruned()
return s
此剪枝器可以像其他 Optuna 剪枝器一样使用。在以下示例中,我们创建一个研究并使用 MyPruner
类对其进行优化。
pruner = MyPruner(upper_threshold=100, n_warmup_steps=5)
study = optuna.create_study(pruner=pruner)
study.optimize(objective, n_trials=100)
在实现您自己的剪枝器后,您可以将其注册到 OptunaHub。有关如何将您的剪枝器注册到 OptunaHub 的信息,请参阅 如何将您的软件包注册到 OptunaHub。有关实现剪枝器的更多信息,请参阅 用户定义的剪枝器文档。
脚本总运行时间: (0 分钟 4.715 秒)