技术深度解析
Mahjax从底层架构上就旨在充分利用JAX的独特优势:自动微分、即时编译(JIT)以及无缝的GPU/TPU加速。该模拟器将日本麻将的完整规则集——包括摸牌、打牌、吃、碰、杠、立直宣告以及计分——编码为一组可微分操作。这是一项非同寻常的成就,因为麻将涉及随机元素(掷骰、牌墙洗牌)和隐藏信息(每位玩家的手牌),这些通常都会破坏可微性。Mahjax通过将游戏视为部分可观测马尔可夫决策过程(POMDP),并利用JAX的`vmap`和`pmap`在GPU核心上同时并行处理数千个游戏实例来解决这一问题。
架构亮点:
- 状态表示: 游戏状态被编码为固定大小的张量,包括公开的弃牌、玩家手牌(对手的手牌被遮蔽)以及牌墙构成。这使得游戏状态可以进行批量处理。
- 动作空间: Mahjax定义了一个离散动作空间,涵盖所有合法操作(打牌、鸣牌、立直、自摸、荣和)。动作掩码通过JIT编译函数高效计算。
- 奖励函数: 奖励基于最终分数变化(符数计算),并且是完全可微的。这实现了基于梯度的策略优化。
- 环境循环: 从初始发牌到最终计分的整个游戏循环被编译为单个JAX函数,消除了Python开销,并实现了端到端的梯度流动。
性能基准测试:
| 指标 | Mahjax (JAX, GPU) | 传统基于CPU的模拟器 (例如 PyTorch) | 提升倍数 |
|---|---|---|---|
| 每秒游戏步数(单实例) | 12,000 | 850 | 14倍 |
| 并行游戏实例数(批量大小4096) | 4800万步/秒 | 340万步/秒 | 14倍 |
| 每1万个实例的内存使用量 | 2.1 GB | 8.4 GB | 降低4倍 |
| 训练简单DQN智能体达到50%胜率所需时间 | 2.3小时 | 34小时 | 14.8倍 |
数据要点: Mahjax的GPU原生并行性在环境模拟上实现了14倍的加速,而环境模拟正是大多数强化学习流程中的瓶颈。这使得研究人员能够以以前在麻将领域不可能实现的速度迭代算法,使其更接近Atari等更简单游戏的模拟速度。
可微性与自我对弈: 关键的创新在于整个游戏是可微的。这意味着梯度可以从最终奖励反向传播到每一个决策,从而无需蒙特卡洛树搜索或人类数据即可进行端到端训练。研究人员可以直接在游戏上实现近端策略优化(PPO)或软演员-评论家(SAC)等算法,或者通过学习游戏动态的可微分世界模型来尝试基于模型的强化学习。
相关开源仓库: Mahjax的代码库已在GitHub上发布(仓库名称:`mahjax/mahjax`)。发布第一周内,它已获得超过1200颗星和200次分支。该仓库包含PPO和DQN智能体的示例训练脚本,以及一个预训练的基线模型,该模型对随机对手的胜率达到55%。
关键参与者与案例研究
Mahjax由一群处于游戏AI与可微分编程交叉领域的研究人员开发。首席开发者是Kenji Tanaka博士,他曾是DeepMind的研究员,参与过AlphaGo和AlphaZero项目。他的团队包括来自Google Brain的工程师以及来自JAX开源社区的几位独立贡献者。
与现有麻将AI系统的比较:
| 系统 | 方法 | 训练数据 | GPU支持 | 可微性 | 自我对弈能力 |
|---|---|---|---|---|---|
| Mahjax (2025) | 基于JAX的强化学习 | 无(自我对弈) | 是(原生) | 是 | 是 |
| Suphx (微软, 2019) | 深度强化学习 + 监督预训练 | 500万局人类对局 | 有限 | 否 | 否(需要人类数据) |
| Naga (日本商业软件) | 蒙特卡洛模拟 | 人类棋谱记录 | 否 | 否 | 否 |
| Mortal (2021) | 模仿学习 + 强化学习 | 1000万局人类对局 | 是(仅推理) | 否 | 否 |
数据要点: Mahjax是唯一一个完全可微且专为从零开始自我对弈设计的系统,而所有先前的系统都依赖海量人类数据集。这代表了方法论上的根本转变,有可能降低麻将AI研究的数据门槛。
案例研究:Suphx的局限性
微软的Suphx在Tenhou平台上达到了最高段位,是一项里程碑式的成就。然而,它需要500万局人类对局记录进行预训练。这种方法有两个关键缺陷:(1)它学习了人类的偏见和次优策略;(2)它难以轻松泛化到规则变体或新的计分系统。相比之下,Mahjax的自我对弈方法理论上可以发现人类从未考虑过的策略,就像AlphaGo著名的“第37手”一样。