jax-js 将 JAX 带入浏览器
jax-js 将 JAX 高性能数值计算与自动微分能力引入浏览器,通过生成 Wasm/WebGPU 内核绕开 JS 性能瓶颈,实现原生级速度。
近日,开发者 Eric Zhang 正式发布了 jax-js,这是一个专为 Web 平台设计的纯 JavaScript 机器学习框架。该项目的核心愿景是将 Google DeepMind 旗下广受欢迎的 JAX 框架的强大能力引入浏览器环境,旨在让前端开发者也能在客户端直接利用高性能的数值计算与自动微分功能,从而为构建无需依赖后端的实时交互式 AI 应用开辟了新的道路。您可以在 Eric Zhang 的博客中阅读到关于此项目的详细构思与心路历程。
长期以来,JavaScript 在需要密集数值计算的机器学习领域一直处于相对劣势。其根本原因在于,JavaScript 引擎的即时编译(JIT)优化并非为处理紧密的数值循环而设计,甚至缺乏原生的快速整数类型支持,这导致其计算性能难以满足现代深度学习模型的需求。然而,随着 WebAssembly 和 WebGPU 两项关键 Web 技术的成熟与普及,游戏规则正在被改写。jax-js 正是抓住了这一技术浪潮,它通过将开发者编写的 JavaScript 逻辑即时编译成高效的 WebAssembly 和 WebGPU 内核,使得复杂的数值计算程序能够以接近原生的速度在浏览器中执行,从而巧妙地绕过了传统 JavaScript 解释器带来的性能瓶颈。
在编程模型与 API 设计上,jax-js 高度还原了 JAX 框架的设计哲学。它完整支持程序追踪与即时编译(JIT),能够将开发者用 JavaScript 编写的计算图动态地转化为可在 GPU 上执行的着色器指令。尽管由于 JavaScript 语言本身的限制,它无法像 Python 中的 JAX 那样通过运算符重载(如 *)来实现优雅的数学表达式,而必须采用 .mul() 这类显式的方法调用,但其整体 API 接口与 NumPy 和 JAX 保持了高度一致,极大地降低了开发者的学习与迁移成本。此外,为了应对 JavaScript 缺乏引用计数和确定性析构函数所带来的内存管理挑战,jax-js 创造性地借鉴了 Rust 语言的所有权语义,通过一套名为 .ref 的系统来精细地管理张量内存的生命周期,有效防止了内存泄漏。
从功能特性来看,jax-js 完整保留了 JAX 框架的精髓,包括用于计算梯度的 自动微分、用于批量数据处理的 向量化变换,以及用于性能优化的 内核融合。项目作者展示了一个极具说服力的案例:在浏览器中,仅使用 jax-js 框架,便能从零开始训练一个识别 MNIST 手写数字的神经网络,并在数秒内达到超过 99% 的准确率。另一个更具实践意义的演示是,它能实时处理一部包含 18 万字的文学巨著,通过集成 CLIP 嵌入模型,实现毫秒级的语义搜索,这充分展现了其在客户端进行复杂 AI 推理的潜力。
性能表现是 jax-js 最引人注目的亮点之一。根据基准测试,在搭载 M4 Pro 芯片的设备上,其矩阵乘法的计算能力超过了 3 TFLOPs。在特定的基准测试场景中,其性能甚至超越了更为成熟的 TensorFlow.js 和 ONNX Runtime Web 等框架。这一卓越表现主要归功于其先进的编译器架构。与许多依赖预编译静态库的框架不同,jax-js 的编译器能够根据输入张量的具体形状动态地进行优化并生成高度定制化的计算内核,从而实现极致的性能调优。
深入其技术架构,jax-js 将整个框架清晰地划分为两个部分:负责程序追踪、自动微分和计算图构建的前端,以及负责执行优化后计算内核的后端。其自动微分系统的实现参考了 Tinygrad 的简洁设计,基于数学上的对偶变换原理。这种设计使得框架开发者在实现基础运算的一阶导数规则后,系统便能自动推导出任意高阶的导数,架构十分优雅。这种清晰的分离设计不仅保证了代码的可维护性,也为未来实现更复杂的内核融合与跨平台优化提供了极高的灵活性。
目前,jax-js 项目已在 GitHub 上开源。尽管它在某些方面,如卷积运算的深度优化和 WebAssembly 多线程的全面支持上,仍有进一步的提升空间,但它已经有力地证明了在浏览器环境中构建一个完整、高性能机器学习生态系统的可行性。对于所有希望探索前沿 Web 技术,并致力于在不依赖后端服务器的情况下,打造实时、交互式人工智能应用的开发者而言,jax-js 的出现无疑开启了一扇充满可能性的新大门。
原文链接: jax-js:将高性能机器学习引入浏览器的新范式





