JAX_(ライブラリ)とは? わかりやすく解説

Weblio 辞書 > 辞書・百科事典 > 百科事典 > JAX_(ライブラリ)の意味・解説 

JAX (ライブラリ)

出典: フリー百科事典『ウィキペディア(Wikipedia)』 (2025/04/30 00:12 UTC 版)

JAX
開発元 GoogleNVIDIA[1]
初版 2018年12月 (6年前) (2018-12)[2]
最新版
0.5.0 / 2025年1月18日 (3か月前) (2025-01-18)[3]
リポジトリ jax - GitHub
プログラミング
言語
Python
対応OS WindowsmacOSLinux
プラットフォーム
種別 数値計算ライブラリ
ライセンス Apache License 2.0
公式サイト jax.readthedocs.io
テンプレートを表示

JAXは、高速な数値計算と大規模な機械学習のために設計されたPythonオープンソースのライブラリ[6]NumPy風の構文で書かれたPythonのソースコードCPUGPUAIアクセラレータ[7]コンパイルする実行時コンパイラ自動微分などを含む。

実行時コンパイラは、JAXからOpenXLAのXLAにコンパイルし、そこから先はハードウェア次第だが、多くのCPUとGPUはLLVMを経由してコンパイルされる[8]

基本的な使用方法

下記のソースコードのように、関数に @jit を付けることにより、その部分が実行時コンパイルされる。同一のソースコードで、CPUだけでなく、GPUやAIアクセラレータでも動作させることが可能である。詳細は後述するが、@jitの中に書けるのは普通のPythonのプログラムではなく、Pythonの構文を使用した純粋関数型言語である。

import jax.numpy as jnp
from jax import jit

@jit
def f(a, b):
    return a + b

x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(f(x, x))

map を自動ベクトル化した vmap があり、a * 2 をあえて vmap を使用して書いた場合、下記のように書ける。SIMDを活用したプログラムにコンパイルされる。[9]

from jax import jit, vmap

@jit
def f(a):
    return vmap(lambda x: x * 2)(a)

Numbaとの違い

似たようなライブラリとしてNumbaがあるが、以下の違いがある。純粋関数型にすることにより色々な最適化がかかっている。関数型言語としての分類は、純粋、正格評価、型を明示する必要が無い静的型付けである。

相違点 JAX Numba
設計思想 純粋関数型。配列は不変で、形状(shape)はコンパイル時に静的に確定していなければならない。[10][11] 手続き型。配列の破壊的操作が可能。
if,match,while,for文 利用不可。代用関数が用意されている。 利用可能[12]
対象ハードウェア CPU・GPU・AIアクセラレータ全てで同一のソースコードで可能。 CPUとNVIDIA CUDAに対応しているが、全く異なるソースコードが必要。[13]
自動微分 対応[14] 非対応

純粋関数型であるため、乱数を使用する際に、下記のように、乱数生成のキーを明示的に作り直さなければならない。[15]

key, subkey = jax.random.split(key)
x = jax.random.normal(subkey)

配列を書き換える際は、手続き型では x[10] = 20 で良い場合も、 y = x.at[10].set(20) という構文になり、x と y は異なるインスタンスになる。ただし、以後 x を使用しない場合は、x に破壊的書き換えして y とする最適化が実行される。[16]

if文とmatch文

JAXではPythonのif文とmatch文は基本的にはそのままでは使用できない。下記が用意されている。

  • jax.lax.cond: Pythonのif文に対応するもので、例えば cond(x == 0, lambda: 10, lambda: 20) の様に使用し、True/Falseに応じてlambda式が実行される。JAXは正格評価の関数型言語のため、True/Falseが決まった後に分岐先の値を遅延評価するためにlambda式の中に入れる。[17]
  • jax.lax.switch: condを3択以上に出来るようにした物で、例えば switch(x, (lambda: 10, lambda: 20, lambda: 30)) の様に使用する。[18]
  • jax.lax.select: boolean配列に対してif文を使用する物で、例えば、xが配列の時 select(x == 0, jnp.array([1, 2]), jnp.array([3, 4])) の様に使用し、x == 0 が True/False に応じて各要素が振り分けられる。[19]
  • jax.lax.select_n: select を swtich の様に3択以上に出来るようにした物。[20]

while文とfor文

JAXではPythonのwhile文とfor文は基本的にはそのままでは使用できず、ループ回数が定数の場合でPythonのfor文をそのまま使用した場合は、ループアンロールされる。[21]

ループ構造を作るものとして下記が用意されている。

  • 関数型言語の fold 相当:jax.lax.fori_loop[22] と jax.lax.scan[23]
  • 関数型言語の unfold 相当:jax.lax.while_loop[24]
  • 関数型言語の map 相当:jax.vmap と jax.lax.map[25]

純粋関数型のため、scan, fori_loop, while_loop は全て前の計算結果を次に渡すという形となっている。

自動微分

jax.grad にて自動微分できる。例えば、最急降下法は下記で実装できる。init_x から始めて、fori_loop にて iter_count 回、計算を反復している。 カテゴリ

  • コモンズ
  • ウィキブックス
  • Portal:コンピュータ



  • 英和和英テキスト翻訳>> Weblio翻訳
    英語⇒日本語日本語⇒英語
      
    •  JAX_(ライブラリ)のページへのリンク

    辞書ショートカット

    すべての辞書の索引

    「JAX_(ライブラリ)」の関連用語

    JAX_(ライブラリ)のお隣キーワード
    検索ランキング

       

    英語⇒日本語
    日本語⇒英語
       



    JAX_(ライブラリ)のページの著作権
    Weblio 辞書 情報提供元は 参加元一覧 にて確認できます。

       
    ウィキペディアウィキペディア
    All text is available under the terms of the GNU Free Documentation License.
    この記事は、ウィキペディアのJAX (ライブラリ) (改訂履歴)の記事を複製、再配布したものにあたり、GNU Free Documentation Licenseというライセンスの下で提供されています。 Weblio辞書に掲載されているウィキペディアの記事も、全てGNU Free Documentation Licenseの元に提供されております。

    ©2025 GRAS Group, Inc.RSS