Thinking Like Transformers 这篇论文中提出了 transformer 类的计算框架,这个框架直接计算和模仿 Transformer 计算。使用 RASP 编程语言,使每个程序编译成一个特殊的 Transformer。
!pip install git+https://github.com/srush/RASPy
在说起语言本身前,让我们先看一个例子,看看用 Transformers 编码是什么样的。这是一些计算翻转的代码,即反向输入序列。代码本身用两个 Transformer 层应用 attention 和数学计算到达这个结果。
def flip(:
length = (key(1 == query(1.value(1
flip = (key(length - indices - 1 == query(indices.value(tokens
return flip
flip(
文章目录
第一部分:Transformers 作为代码
第二部分:用 Transformers 编写程序
Transformers 作为代码
这个语言的核心单元是将一个序列转换成相同长度的另一个序列的序列操作。我后面将其称之为 transforms。
输入
tokens
tokens.input([5, 2, 4, 5, 2, 2]
indices
sop = indices
sop.input("goodbye"
前馈网络
tokens == "l"
model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2]
model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2]
(tokens == "l" | (indices == 1
where 提供了一个类似 if
功能的结构。
where((tokens == "h" | (tokens == "l", tokens, "q"
map 使我们可以定义自己的操作,例如一个字符串以 int
转换。(用户应谨慎使用可以使用的简单神经网络计算的操作)
atoi = tokens.map(lambda x: ord(x - ord('0'
atoi.input("31234"
def atoi(seq=tokens:
return seq.map(lambda x: ord(x - ord('0'
op = (atoi(where(tokens == "-", "0", tokens + 2
op.input("02-13"
注意力筛选器
key。
key(tokens
query 也一样
query(tokens
key 或 query
使用,他们会广播到基础序列的长度。
query(1
eq = (key(tokens == query(tokens
eq
选择器的匹配位置偏移 1:
offset = (key(indices == query(indices - 1
offset
key 早于 query 的选择器:
before = key(indices < query(indices
before
key 晚于 query 的选择器:
after = key(indices > query(indices
after
before & eq
使用注意力机制
(请注意:在原始论文中,他们使用一个平均聚合操作并且展示了一个巧妙的结构,其中平均聚合能够代表总和计算。RASPy 默认情况下使用累加来使其简单化并避免碎片化。实际上,这意味着 raspy 可能低估了所需要的层数。基于平均值的模型可能需要这个层数的两倍
(key(tokens == query(tokens.value(1
length = (key(1 == query(1.value(1
length = length.name("length"
length
我们想要计算一个序列的相邻值的和,首先我们向前截断:
WINDOW=3
s1 = (key(indices >= query(indices - WINDOW + 1
s1
s2 = (key(indices <= query(indices
s2
sel = s1 & s2
sel
sum2 = sel.value(tokens
sum2.input([1,3,2,2,2]
def cumsum(seq=tokens:
x = (before | (key(indices == query(indices.value(seq
return x.name("cumsum"
cumsum(.input([3, 1, -2, 3, 1]
层
x = cumsum(length - indices
x.input([3, 2, 3, 5]
用 transformers 进行编程
例如: 给一个字符串 "19492+23919", 我们可以加载正确的输出吗?
挑战一:选择一个给定的索引
加载一个在索引 i
处全元素都有值的序列
def index(i, seq=tokens:
x = (key(indices == query(i.value(seq
return x.name("index"
index(1
挑战二:转换
i 位置将所有 token 移动到右侧。
def shift(i=1, default="_", seq=tokens:
x = (key(indices == query(indices-i.value(seq, default
return x.name("shift"
shift(2
挑战三:最小化
def minimum(seq=tokens:
sel1 = before & (key(seq == query(seq
sel2 = key(seq < query(seq
less = (sel1 | sel2.value(1
x = (key(less == query(0.value(seq
return x.name("min"
minimum(([5,3,2,5,2]
挑战四:第一索引
def first(q, seq=tokens:
return minimum(where(seq == q, indices, 99
first("l"
挑战五:右对齐
ralign(.inputs('xyz___' ='—xyz'" (2 层
def ralign(default="-", sop=tokens:
c = (key(sop == query("_".value(1
x = (key(indices + c == query(indices.value(sop, default
return x.name("ralign"
ralign(("xyz__"
挑战六:分离
def split(v, i, sop=tokens:
mid = (key(sop == query(v.value(indices
if i == 0:
x = ralign("0", where(indices < mid, sop, "_"
return x
else:
x = where(indices > mid, sop, "0"
return x
split("+", 1("xyz+zyr"
split("+", 0("xyz+zyr"
挑战七:滑动
def slide(match, seq=tokens:
x = cumsum(match
y = ((key(x == query(x + 1 & (key(match == query(True.value(seq
seq = where(match, seq, y
return seq.name("slide"
slide(tokens != "<".input("xxxh<<<l"
挑战八:增加
add(.input("683+345"
分成两部分。转制成整形。加入
“683+345” => [0, 0, 0, 9, 12, 8]
计算携带条款。三种可能性:1 个携带,0 不携带,< 也许有携带。
滑动进位系数
“00<100” => 001100"
完成加法
def add(sop=tokens:
# 0 Parse and add
x = atoi(split("+", 0, sop + atoi(split("+", 1, sop
# 1 Check for carries
carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0"
# 2 In parallel, slide carries to their column
carries = atoi(slide(carry != "<", carry
# 3 Add in carries.
return (x + carries % 10
add(("683+345"
683 + 345
1028
完美搞定!
参考资料 & 文内链接:
如果你对这个主题感兴趣想了解更多,请查看论文:Thinking Like Transformers
以及了解更多 RASP 语言
如果你对「形式语言和神经网络」(FLaNN 感兴趣或者有认识感兴趣的人,欢迎邀请他们加入我们的 线上社区!
本篇博文,包含库、Notebook 和博文的内容
本博客文章由 Sasha Rush 和 Gail Weiss 共同编写
译者:innovation64 (李洋