JAX (ライブラリ)とは? わかりやすく解説

Weblio 辞書 > 辞書・百科事典 > 百科事典 > JAX (ライブラリ)の意味・解説 

JAX (ライブラリ)

(Google JAX から転送)

出典: フリー百科事典『ウィキペディア(Wikipedia)』 (2025/02/02 06:59 UTC 版)

JAX
開発元 GoogleNVIDIA[1]
初版 2018年12月 (6年前) (2018-12)[2]
最新版
0.5.0 / 2025年1月18日 (15日前) (2025-01-18)[3]
リポジトリ jax - GitHub
プログラミング
言語
Python
対応OS WindowsmacOSLinux
プラットフォーム
種別 数値計算ライブラリ
ライセンス Apache License 2.0
公式サイト jax.readthedocs.io
テンプレートを表示

JAXは、高速な数値計算と大規模な機械学習のために設計されたPythonオープンソースのライブラリ[6]NumPy風の構文で書かれたPythonのソースコードCPUGPUAIアクセラレータ[7]コンパイルする実行時コンパイラ自動微分などを含む。

実行時コンパイラは、JAXからOpenXLAのXLAにコンパイルし、そこから先はハードウェア次第だが、多くのCPUとGPUはLLVMを経由してコンパイルされる[8]

基本的な使用方法

下記のソースコードのように、関数に @jit を付けることにより、その部分が実行時コンパイルされる。同一のソースコードで、CPUだけでなく、GPUやAIアクセラレータでも動作させることが可能である。詳細は後述するが、@jitの中に書けるのは普通のPythonのプログラムではなく、Pythonの構文を使用した純粋関数型言語である。

import jax.numpy as jnp
from jax import jit

@jit
def f(a, b):
    return a + b

x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(f(x, x))

map を自動ベクトル化した vmap があり、a * 2 をあえて vmap を使用して書いた場合、下記のように書ける。SIMDを活用したプログラムにコンパイルされる。[9]

from jax import jit, vmap

@jit
def f(a):
    return vmap(lambda x: x * 2)(a)

Numbaとの違い

似たようなライブラリとしてNumbaがあるが、以下の違いがある。純粋関数型にすることにより色々な最適化がかかっている。関数型言語としての分類は、純粋、正格評価、型を明示する必要が無い静的型付けである。

相違点 JAX Numba
設計思想 純粋関数型。配列は不変で、形状(shape)はコンパイル時に静的に確定してないといけない。[10][11] 手続き型。配列の破壊的操作が可能。
if,match,while,for文 利用不可。代用関数が用意されている。 利用可能[12]
対象ハードウェア CPU・GPU・AIアクセラレータ全てで同一のソースコードで可能。 CPUとNVIDIA CUDAに対応しているが、全く異なるソースコードが必要。[13]
自動微分 対応[14] 非対応

純粋関数型であるため、乱数を使用する際に、下記のように、乱数生成のキーを明示的に作り直さないといけない。[15]

key, subkey = jax.random.split(key)
x = jax.random.normal(subkey)

配列を書き換える際は、手続き型では x[10] = 20 で良い場合も、 y = x.at[10].set(20) という構文になり、x と y は異なるインスタンスになる。ただし、以後 x を使用しない場合は、x に破壊的書き換えして y とする最適化が実行される。[16]

if文とmatch文

JAXではPythonのif文とmatch文は基本的にはそのままでは使用できない。下記が用意されている。

  • jax.lax.cond: Pythonのif文に対応するもので、例えば cond(x == 0, lambda: 10, lambda: 20) の様に使用し、True/Falseに応じてlambda式が実行される。JAXは正格評価の関数型言語のため、True/Falseが決まった後に分岐先の値を遅延評価するためにlambda式の中に入れる。[17]
  • jax.lax.switch: condを3択以上に出来るようにした物で、例えば switch(x, (lambda: 10, lambda: 20, lambda: 30)) の様に使用する。[18]
  • jax.lax.select: boolean配列に対してif文を使用する物で、例えば、xが配列の時 select(x == 0, jnp.array([1, 2]), jnp.array([3, 4])) の様に使用し、x == 0 が True/False に応じて各要素が振り分けられる。[19]
  • jax.lax.select_n: select を swtich の様に3択以上に出来るようにした物。[20]

while文とfor文

JAXではPythonのwhile文とfor文は基本的にはそのままでは使用できず、ループ回数が定数の場合でPythonのfor文をそのまま使用した場合は、ループアンロールされる。[21]

ループ構造を作るものとして下記が用意されている。

  • 関数型言語の fold 相当:jax.lax.fori_loop[22] と jax.lax.scan[23]
  • 関数型言語の unfold 相当:jax.lax.while_loop[24]
  • 関数型言語の map 相当:jax.vmap と jax.lax.map[25]

純粋関数型のため、scan, fori_loop, while_loop は全て前の計算結果を次に渡すという形となっている。

自動微分

jax.grad にて自動微分できる。例えば、最急降下法は下記で実装できる。init_x から始めて、fori_loop にて iter_count 回、計算を反復している。 カテゴリ

  • コモンズ
  • ウィキブックス
  • Portal:コンピュータ



  • 英和和英テキスト翻訳>> Weblio翻訳
    英語⇒日本語日本語⇒英語
      
    •  JAX (ライブラリ)のページへのリンク

    辞書ショートカット

    すべての辞書の索引

    「JAX (ライブラリ)」の関連用語

    1
    32% |||||


    3
    12% |||||


    5
    12% |||||

    6
    12% |||||

    7
    8% |||||

    8
    8% |||||

    9
    8% |||||

    10
    8% |||||

    JAX (ライブラリ)のお隣キーワード
    検索ランキング

       

    英語⇒日本語
    日本語⇒英語
       



    JAX (ライブラリ)のページの著作権
    Weblio 辞書 情報提供元は 参加元一覧 にて確認できます。

       
    ウィキペディアウィキペディア
    All text is available under the terms of the GNU Free Documentation License.
    この記事は、ウィキペディアのJAX (ライブラリ) (改訂履歴)の記事を複製、再配布したものにあたり、GNU Free Documentation Licenseというライセンスの下で提供されています。 Weblio辞書に掲載されているウィキペディアの記事も、全てGNU Free Documentation Licenseの元に提供されております。

    ©2025 GRAS Group, Inc.RSS