如何在 OptunaHub 中实现您的剪枝器

OptunaHub 支持 Optuna 的剪枝器、采样器和可视化函数。本教程展示了如何在 OptunaHub 中实现并注册您自己的剪枝器。

通常,Optuna 提供 BasePruner 类来帮助您实现自己的采样器。您可以通过继承此类来实现自己的剪枝器。您需要安装 optuna 来实现自己的剪枝器。

$ pip install optuna

首先,导入 optuna 和其他必需的模块。

from __future__ import annotations

import optuna
from optuna.pruners import BasePruner

接下来,通过继承 BasePruner 类来定义您自己的剪枝器类。在本例中,我们实现了一个基于给定阈值停止目标函数的剪枝器。

class MyPruner(BasePruner):  # type: ignore
    def __init__(self, upper_threshold: float, n_warmup_steps: int) -> None:
        self._upper_threshold = upper_threshold
        self._n_warmup_steps = n_warmup_steps

    # You need to implement `prune` method.
    # This method returns true if it stops objective function, otherwise false.
    # It stops the objective function if the intermediate value exceeds the threshold.
    # Note that first `n_warmup_steps` steps are not pruned.
    def prune(
        self,
        study: optuna.study.Study,
        trial: optuna.trial.FrozenTrial,
    ) -> bool:
        step = trial.last_step
        if step is None:
            return False

        if step < self._n_warmup_steps:
            return False

        if trial.intermediate_values[step] > self._upper_threshold:
            return True

        return False

在此示例中,目标函数是一个简单的二次函数。它有 20 个变量,返回这些变量的平方和。

def objective(trial: optuna.trial.Trial) -> float:
    s = 0.0
    for step in range(20):
        x = trial.suggest_float(f"x_{step}", -5, 5)
        s += x**2
        trial.report(s, step)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return s

此剪枝器可以像其他 Optuna 剪枝器一样使用。在以下示例中,我们创建一个研究并使用 MyPruner 类对其进行优化。

pruner = MyPruner(upper_threshold=100, n_warmup_steps=5)
study = optuna.create_study(pruner=pruner)
study.optimize(objective, n_trials=100)

在实现您自己的剪枝器后,您可以将其注册到 OptunaHub。有关如何将您的剪枝器注册到 OptunaHub 的信息,请参阅 如何将您的软件包注册到 OptunaHub。有关实现剪枝器的更多信息,请参阅 用户定义的剪枝器文档

脚本总运行时间: (0 分钟 4.715 秒)

由 Sphinx-Gallery 生成的图库