logo
Published on

KAN さわってみた

Authors

KAN さわってみた

随時,更新予定

概要

  • この記事では,KAN の詳細(理論とか)については触れません.(理解していないので...)
  • どんな環境を用意し,どんなコードを書けば KAN を動かすことができるのか,というタスクに対するソリューションの1つを紹介しているのみです.
  • KAN の詳しい話を知りたい方は,参考のセクションに記載しているリンクが手助けになるかもしれません.

導入

最近,X で KAN という Neural Network(NN) に代わり得る新たな network を知った. なんでも,NN の学習でいうところの活性化関数を学習するとか. しかも,学習した活性化関数みたいなものを既存の関数(sinとかx^2とかexpとか)に当てはめることで,学習結果が解釈しやすいとか...

とりあえず,KAN の論文はこちら.(こちらが KAN の最初の提案論文かどうかは未調査です.)

https://arxiv.org/abs/2404.19756

ドキュメントがあったので,こちらを参考にコードを動かしてみよう.

https://kindxiaoming.github.io/pykan/index.html

KAN を試す

今回の環境は以下の感じです.

  • macOS Sonoma version 14.5
  • VSCode: 1.90.1
  • Python: 3.10.14
$ python -V
Python 3.10.14
  • Git 2.45.2
$ git -v
git version 2.45.2

手順・コードなど

# プロジェクトの dir を作成し,移動する
mkdir kan-tutorial
cd kan-tutorial

# 仮想環境を作成し,起動する
python -m venv .venv
source .venv/bin/activate

# 必要なライブラリをインストールする
pip install -U pip
pip install pykan notebook torch numpy scikit-learn matplotlib tqdm

# ipynb ファイルを作成する
touch main.ipynb
  • main.ipynb に以下のコードを記述する
main.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kan import *\n",
    "\n",
    "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
    "model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
    "f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)\n",
    "dataset = create_dataset(f, n_var=2)\n",
    "dataset[\"train_input\"].shape, dataset[\"train_label\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot KAN at initialization\n",
    "model(dataset[\"train_input\"])\n",
    "model.plot(beta=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the model\n",
    "model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.prune()\n",
    "model.plot(mask=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.prune()\n",
    "model(dataset[\"train_input\"])\n",
    "model.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train(dataset, opt=\"LBFGS\", steps=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mode = \"auto\"  # \"manual\"\n",
    "\n",
    "if mode == \"manual\":\n",
    "    # manual mode\n",
    "    model.fix_symbolic(0, 0, 0, \"sin\")\n",
    "    model.fix_symbolic(0, 1, 0, \"x^2\")\n",
    "    model.fix_symbolic(1, 0, 0, \"exp\")\n",
    "elif mode == \"auto\":\n",
    "    # automatic mode\n",
    "    lib = [\"x\", \"x^2\", \"x^3\", \"x^4\", \"exp\", \"log\", \"sqrt\", \"tanh\", \"sin\", \"abs\"]\n",
    "    model.auto_symbolic(lib=lib)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.symbolic_formula()[0][0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
  • VSCode で main.ipynb を開き,コードを実行する
# VSCode を起動する
code .

# notebook で code を実行する

実行結果

学習したモデルのプロットはこちら.

KAN plot after training

なんかドキュメントの結果とは違ったなあ... 最終的に得られた式は,

1.0e1.0x22+1.0sin(3.14x1)0.01log{0.3(0.08x1)4+0.02sin(1.74x2+1.6)+0.07}0.021.0e^{1.0 x^2_2 + 1.0 \sin(3.14x_1)} - 0.01 \log{\{ 0.3(0.08 - x_1)^4 + 0.02 \sin(1.74 x_2 + 1.6) + 0.07 }\} - 0.02

であった(数式の書き方とか Katex の書き方がよくわからんので,変でもご容赦ください...). ちょっと省略できるところは省略したり,項を移動させたりしてもう少し綺麗にすると,

esin(3.14x1)+x220.01log{0.3(0.08x1)4+0.02sin(1.74x2+1.6)+0.07}0.02e^{\sin(3.14x_1) + x^2_2} - 0.01 \log{\{ 0.3(0.08 - x_1)^4 + 0.02 \sin(1.74 x_2 + 1.6) + 0.07 }\} - 0.02

とできた.作ったデータセットは,

f(x,y)=esin(πx)+y2f(x,y) = e^{\sin(\pi x) + y^2}

なので,xxx1x_1 に,yyx2x_2 に対応していると思われる. ただ,なんか

0.01log{0.3(0.08x1)4+0.02sin(1.74x2+1.6)+0.07}0.02- 0.01 \log{\{ 0.3(0.08 - x_1)^4 + 0.02 \sin(1.74 x_2 + 1.6) + 0.07 }\} - 0.02

という部分があるのが気になる.π\pi3.143.14 の差分を表しているのか・・・?(調べろ,自分)

感想

今回は,このドキュメントをもとにコードを動かしてみました.

いろいろ調べる中で,KAN がざっくり何をするのかはわかった気がしますが,詳細(KAN がなぜうまく機能するか?今回動かしたコードの意味は?)はまだまだ理解が追いついていません. X でも KAN のポストをちょいちょい見かけ,GPT に KAN を組み込んだ的なものも見かけたので,もしかしたら今後,KAN が注目されていくのかもしれません. 最近の Deep Learning を支える NN の仕組みが最適であるかどうかは,まだまだ研究の余地があると思うので,引き続き追っていきたいと思います.

参考・関連リンク

いろいろな KAN の実装があるみたい?なので,見つけた github や KAN に関連しそうな論文を雑に載せておきます.