Understanding JAX Arrays
Table of Contents
- JAX array
- Key characteristics
- Duck typing
- Code Examples for Key Concepts
- Array Usage Examples
- Key Differences from NumPy
JAX array
Key characteristics
JAX is a Python library that provides automatic differentiation, JIT compilation, and array-oriented numerical computation similar to NumPy. It offers a NumPy-like interface for computation across CPU, GPU, and TPU hardware, delivering significantly better performance than standard NumPy through JIT compilation powered by Open XLA.
In particular, automatic differentiation in JAX enhances the efficiency of gradient computation, which is essential for machine learning and optimization tasks.
Installing JAX is straightforward. Use the following command in a terminal:
pip install jax
For GPU functionality (with CUDA 13), adjust the command as follows:
pip install -U "jax[cuda13]"
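After installation, a quick sanity check (assuming a standard Python environment) confirms that JAX imports and reports its available devices:

```python
import jax

print(jax.__version__)   # installed JAX version
print(jax.devices())     # available backends, e.g. [CpuDevice(id=0)]
```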
Duck typing
Duck typing is a fundamental characteristic of JAX. Duck typing is a programming concept commonly used in dynamic languages like Python, Ruby, and JavaScript. The concept originates from the saying: “If it walks like a duck and it quacks like a duck, then it must be a duck.”
In statically typed languages (such as Java or C++), the compiler verifies whether an object belongs to a specific class or implements a particular interface before allowing its use. In contrast, duck typing emphasizes what an object can do—its methods and properties—rather than its inheritance hierarchy or declared type. If an object possesses the expected methods or attributes, it can be utilized, even if its class does not formally inherit from a specific interface.
Keep in mind that JAX arrays are always immutable, unlike NumPy’s arrays.
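In practice, duck typing means most `jax.numpy` functions accept NumPy arrays (or anything sufficiently array-like) directly; a minimal sketch:

```python
import numpy as np
import jax.numpy as jnp

# A plain NumPy array "quacks like" an array, so JAX accepts it as-is.
x_np = np.array([1.0, 2.0, 3.0])
total = jnp.sum(x_np)   # works without conversion; the result is a jax.Array

print(total)            # 6.0
```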
Code Examples for Key Concepts
JAX provides several powerful features through a clean, functional API:
- NumPy Compatibility: JAX's `jax.numpy` module mirrors NumPy's API, allowing seamless transition of NumPy code to JAX while enabling acceleration on hardware accelerators.
- JIT Compilation: The `@jax.jit` decorator or `jax.jit()` function transforms Python functions into optimized XLA operations. The first call triggers compilation (visible as slower execution), while subsequent calls use the cached, optimized version for significant speedups, especially with large arrays or complex operations.
- Automatic Differentiation: JAX's `grad()` function computes gradients automatically by tracing the computational graph. It supports forward-mode (`jax.jacfwd`), reverse-mode (`jax.jacrev`), and higher-order derivatives through function composition.
- Duck Typing: JAX functions work with any object that supports array-like operations (shape, dtype, indexing). This enables interoperability between JAX arrays, NumPy arrays, and other array-like objects without explicit type checking.
- Immutability: Unlike NumPy arrays, JAX arrays are immutable. Instead of in-place modification, operations return new arrays. The `.at[]` syntax provides a functional interface for updates that resembles in-place operations while maintaining immutability.
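The higher-order differentiation mentioned above comes from composing these transforms. For example, a Hessian can be built by nesting reverse mode inside forward mode (one common composition; the function here is illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x))

x = jnp.array([0.0, 1.0])

# Hessian = Jacobian of the gradient: forward-over-reverse composition.
hessian = jax.jacfwd(jax.jacrev(f))(x)

# For f(x) = sum(sin(x)), the Hessian is diag(-sin(x)).
print(hessian)
```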
```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# === 1. NumPy-like Interface Demo ===
print("=== NumPy-like Interface ===")

# Using JAX's NumPy API
x_jax = jnp.array([1., 2., 3., 4.])
x_numpy = np.array([1., 2., 3., 4.])

# Same basic operations
print(f"JAX sum: {jnp.sum(x_jax)}")
print(f"NumPy sum: {np.sum(x_numpy)}")

# === 2. JIT Compilation Speed Demo ===
print("\n=== JIT Compilation ===")

# Define a computationally intensive function
def some_function(x):
    for _ in range(1000):  # Simulating complex computation
        x = jnp.sin(x) * jnp.cos(x) + x**0.5
    return x

jit_some_function = jax.jit(some_function)

# Test the performance difference. block_until_ready() waits for JAX's
# asynchronous dispatch to finish, so the timings are meaningful.
x = jnp.ones((1000, 1000))

start = time.perf_counter()
result1 = some_function(x).block_until_ready()       # Not compiled
print(f"Uncompiled: {time.perf_counter() - start:.3f} s")

start = time.perf_counter()
result2 = jit_some_function(x).block_until_ready()   # First call: compiles
print(f"First JIT call (includes compilation): {time.perf_counter() - start:.3f} s")

start = time.perf_counter()
result3 = jit_some_function(x).block_until_ready()   # Cached, optimized version
print(f"Subsequent JIT call: {time.perf_counter() - start:.3f} s")

# === 3. Automatic Differentiation Demo ===
print("\n=== Automatic Differentiation ===")

def loss_function(params, x, y):
    w, b = params
    predictions = w * x + b
    return jnp.mean((predictions - y)**2)

# Create some data
x_data = jnp.array([1., 2., 3., 4.])
y_data = jnp.array([2., 4., 6., 8.])
params = (jnp.array(1.5), jnp.array(0.5))  # w, b

# Use grad to compute gradients with respect to params (argnums=0)
grad_fn = jax.grad(loss_function, argnums=0)
gradients = grad_fn(params, x_data, y_data)
print(f"Gradients (dw, db): {gradients}")

# === 4. Duck Typing Demo ===
print("\n=== Duck Typing Example ===")

def process_array(arr):
    """Accepts any "array-like" object that has shape and dtype attributes."""
    return f"Shape: {arr.shape}, Dtype: {arr.dtype}"

# Test different types of arrays
print(f"NumPy array: {process_array(np.array([1, 2, 3]))}")
print(f"JAX array: {process_array(jnp.array([1, 2, 3]))}")

# === 5. Immutability Demo ===
print("\n=== Array Immutability ===")

# JAX arrays are immutable
jax_array = jnp.array([1, 2, 3, 4])
print(f"Original array: {jax_array}")

# Attempting in-place modification raises an error
try:
    jax_array[0] = 10
except TypeError as e:
    print(f"Error when trying to modify JAX array: {e}")

# Correct approach: create a new array with .at[]
new_array = jax_array.at[0].set(10)
print(f"New array after update: {new_array}")
print(f"Original array unchanged: {jax_array}")

# Compare with NumPy mutability
numpy_array = np.array([1, 2, 3, 4])
numpy_array[0] = 10
print(f"NumPy array after modification: {numpy_array}")
```
Array Usage Examples
```python
import jax
import jax.numpy as jnp

print("=== JAX Array Basics ===")

# 1. Creating Arrays
arr1 = jnp.array([1, 2, 3, 4])   # From list
arr2 = jnp.zeros((3, 3))         # Zero matrix
arr3 = jnp.ones((2, 4))          # Ones matrix
arr4 = jnp.eye(5)                # Identity matrix
arr5 = jnp.arange(0, 10, 2)      # Range array
arr6 = jnp.linspace(0, 1, 5)     # Linearly spaced array
arr7 = jax.random.normal(jax.random.PRNGKey(0), (3, 3))  # Random numbers

print("1. Array creation examples:")
print(f"   zeros:\n{arr2}")
print(f"   arange: {arr5}")
print(f"   random:\n{arr7}")

# 2. Array Properties
print("\n2. Array attributes:")
matrix = jnp.array([[1, 2, 3], [4, 5, 6]])
print(f"   Shape: {matrix.shape}")
print(f"   Dtype: {matrix.dtype}")
print(f"   Size: {matrix.size}")
print(f"   Ndim: {matrix.ndim}")

# 3. Indexing and Slicing
print("\n3. Indexing and slicing:")
arr = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f"   arr[0]: {arr[0]}")
print(f"   arr[:, 1]: {arr[:, 1]}")
print(f"   arr[1:, :2]:\n{arr[1:, :2]}")

# 4. Mathematical Operations
print("\n4. Mathematical operations:")
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
print(f"   a + b = {a + b}")
print(f"   a * b = {a * b}")
print(f"   jnp.dot(a, b) = {jnp.dot(a, b)}")
print(f"   jnp.exp(a) = {jnp.exp(a)}")
print(f"   jnp.sin(a) = {jnp.sin(a)}")

# 5. Broadcasting
print("\n5. Broadcasting:")
matrix = jnp.ones((3, 3))
vector = jnp.array([1, 2, 3])
print(f"   matrix + vector:\n{matrix + vector}")

# 6. Array Reshaping
print("\n6. Array reshaping:")
arr = jnp.arange(12)
print(f"   Original shape: {arr.shape}")
print(f"   Reshaped to (3, 4):\n{arr.reshape(3, 4)}")
print(f"   Flattened: {arr.reshape(-1)}")

# 7. Updating Immutable Arrays
print("\n7. Updating immutable arrays:")
arr = jnp.array([10, 20, 30, 40, 50])

# Using .at[] for functional updates
arr_updated = arr.at[0].set(100)
arr_multi = arr.at[jnp.array([1, 3])].set(jnp.array([200, 400]))
arr_add = arr.at[2].add(100)
arr_mul = arr.at[4].multiply(2)
print(f"   Original: {arr}")
print(f"   Single update: {arr_updated}")
print(f"   Multiple updates: {arr_multi}")
print(f"   Add operation: {arr_add}")
print(f"   Multiply operation: {arr_mul}")

# 8. Device Placement
print("\n8. Device placement:")
x = jnp.array([1, 2, 3])
print(f"   Devices: {x.devices()}")
# Explicit device placement, if needed:
# x_on_gpu = jax.device_put(x, jax.devices('gpu')[0])  # If a GPU is available
```
JAX arrays follow these key usage patterns:

- Creation: Arrays can be created from Python lists, using factory functions (`zeros()`, `ones()`, `eye()`), sequences (`arange()`, `linspace()`), or random number generation.
- Properties: Arrays have standard properties like `shape`, `dtype`, `size`, and `ndim` that mirror NumPy's interface.
- Indexing: JAX supports NumPy-style indexing and slicing, including advanced indexing patterns.
- Operations: Mathematical operations work element-wise with broadcasting support. JAX includes comprehensive mathematical functions in `jax.numpy`.
- Broadcasting: Automatic dimension expansion follows NumPy's broadcasting rules, allowing operations between arrays of different shapes.
- Reshaping: Arrays can be reshaped, flattened, or transposed without copying data when possible.
- Functional Updates: Due to immutability, updates use the `.at[]` API, which returns new arrays with the specified modifications while preserving the original.
- Device Awareness: Arrays track their device placement (CPU/GPU/TPU) and can be explicitly moved between devices.
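Random-number creation, mentioned above, differs from NumPy in one important way: JAX draws are stateless and take an explicit PRNG key, so reproducibility and independence are both under the caller's control. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(42)

# The same key always yields the same numbers...
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(bool((a == b).all()))   # True

# ...so independent draws require splitting the key first.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(bool((a == c).all()))   # False
```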
Key Differences from NumPy
While JAX provides a NumPy-like interface, these critical differences exist:
- Immutable Arrays: JAX arrays cannot be modified in place. All operations return new arrays. This enables functional transformations and ensures referential transparency.
- Functional Paradigm: JAX encourages pure functions without side effects, which is essential for transformations like `jit`, `grad`, and `vmap` to work correctly.
- Random Number Generation: JAX uses explicit PRNG keys for stateless random number generation, unlike NumPy's stateful global random state.
- Device Execution: Operations can execute on accelerators (GPU/TPU) transparently, with data transferred automatically when needed.
- Transformation Composability: JAX transformations (`jit`, `grad`, `vmap`) can be arbitrarily composed, enabling powerful combinations like computing batched gradients of JIT-compiled functions.
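The composability point can be made concrete with a short sketch (the loss function and shapes here are illustrative): a per-example gradient function built by wrapping `grad` in `vmap`, then compiled with `jit`:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

# grad -> gradient w.r.t. w (the first argument)
# vmap -> map over a batch of x while broadcasting the shared w
# jit  -> compile the whole composed function
batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

xs = jnp.arange(6.0).reshape(3, 2)   # a batch of three 2-element inputs
print(batched_grad(2.0, xs))         # one gradient per batch element: 4., 52., 164.
```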