了解 Transformers 是如何“思考”的

Python 投稿 11800 0 评论

了解 Transformers 是如何“思考”的

Thinking Like Transformers 这篇论文中提出了 transformer 类的计算框架,这个框架直接计算和模仿 Transformer 计算。使用 RASP 编程语言,使每个程序编译成一个特殊的 Transformer。

!pip install git+https://github.com/srush/RASPy

在说起语言本身前,让我们先看一个例子,看看用 Transformers 编码是什么样的。这是一些计算翻转的代码,即反向输入序列。代码本身用两个 Transformer 层应用 attention 和数学计算到达这个结果。

def flip(:
    length = (key(1 == query(1.value(1
    flip = (key(length - indices - 1 == query(indices.value(tokens
    return flip
flip(

文章目录

  • 第一部分:Transformers 作为代码

  • 第二部分:用 Transformers 编写程序

Transformers 作为代码

这个语言的核心单元是将一个序列转换成相同长度的另一个序列的序列操作。我后面将其称之为 transforms。

输入

tokens
tokens.input([5, 2, 4, 5, 2, 2]
indices
sop = indices
sop.input("goodbye"

前馈网络

tokens == "l"
model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2]
model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2]
(tokens == "l" | (indices == 1

where 提供了一个类似 if 功能的结构。

where((tokens == "h" | (tokens == "l", tokens, "q"

map 使我们可以定义自己的操作,例如一个字符串以 int 转换。(用户应谨慎使用可以使用的简单神经网络计算的操作)

atoi = tokens.map(lambda x: ord(x - ord('0'
atoi.input("31234"
def atoi(seq=tokens:
    return seq.map(lambda x: ord(x - ord('0' 

op = (atoi(where(tokens == "-", "0", tokens + 2
op.input("02-13"

注意力筛选器

key。

key(tokens

query 也一样

query(tokens

key 或 query 使用,他们会广播到基础序列的长度。

query(1
eq = (key(tokens == query(tokens
eq
  • 选择器的匹配位置偏移 1:

offset = (key(indices == query(indices - 1
offset
  • key 早于 query 的选择器:

before = key(indices < query(indices
before
  • key 晚于 query 的选择器:

after = key(indices > query(indices
after
before & eq

使用注意力机制

(请注意:在原始论文中,他们使用一个平均聚合操作并且展示了一个巧妙的结构,其中平均聚合能够代表总和计算。RASPy 默认情况下使用累加来使其简单化并避免碎片化。实际上,这意味着 raspy 可能低估了所需要的层数。基于平均值的模型可能需要这个层数的两倍

(key(tokens == query(tokens.value(1
length = (key(1 == query(1.value(1
length = length.name("length"
length

我们想要计算一个序列的相邻值的和,首先我们向前截断:

WINDOW=3
s1 = (key(indices >= query(indices - WINDOW + 1  
s1
s2 = (key(indices <= query(indices
s2
sel = s1 & s2
sel
sum2 = sel.value(tokens 
sum2.input([1,3,2,2,2]
def cumsum(seq=tokens:
    x = (before | (key(indices == query(indices.value(seq
    return x.name("cumsum"
cumsum(.input([3, 1, -2, 3, 1]

x = cumsum(length - indices
x.input([3, 2, 3, 5]

用 transformers 进行编程

例如: 给一个字符串 "19492+23919", 我们可以加载正确的输出吗?

挑战一:选择一个给定的索引

加载一个在索引 i 处全元素都有值的序列

def index(i, seq=tokens:
    x = (key(indices == query(i.value(seq
    return x.name("index"
index(1

挑战二:转换

i 位置将所有 token 移动到右侧。

def shift(i=1, default="_", seq=tokens:
    x = (key(indices == query(indices-i.value(seq, default
    return x.name("shift"
shift(2

挑战三:最小化

def minimum(seq=tokens:
    sel1 = before & (key(seq == query(seq
    sel2 = key(seq < query(seq
    less = (sel1 | sel2.value(1
    x = (key(less == query(0.value(seq
    return x.name("min"
minimum(([5,3,2,5,2]

挑战四:第一索引

def first(q, seq=tokens:
    return minimum(where(seq == q, indices, 99
first("l"

挑战五:右对齐

ralign(.inputs('xyz___' ='—xyz'" (2 层

def ralign(default="-", sop=tokens:
    c = (key(sop == query("_".value(1
    x = (key(indices + c == query(indices.value(sop, default
    return x.name("ralign"
ralign(("xyz__"

挑战六:分离

def split(v, i, sop=tokens:

    mid = (key(sop == query(v.value(indices
    if i == 0:
        x = ralign("0", where(indices < mid, sop, "_"
        return x
    else:
        x = where(indices > mid, sop, "0"
        return x
split("+", 1("xyz+zyr"
split("+", 0("xyz+zyr"

挑战七:滑动

def slide(match, seq=tokens:
    x = cumsum(match 
    y = ((key(x == query(x + 1 & (key(match == query(True.value(seq
    seq =  where(match, seq, y
    return seq.name("slide"
slide(tokens != "<".input("xxxh<<<l"

挑战八:增加

add(.input("683+345"
  1. 分成两部分。转制成整形。加入

“683+345” => [0, 0, 0, 9, 12, 8]

  1. 计算携带条款。三种可能性:1 个携带,0 不携带,< 也许有携带。

  2. 滑动进位系数

“00<100” => 001100"

  1. 完成加法

def add(sop=tokens:
    # 0 Parse and add
    x = atoi(split("+", 0, sop + atoi(split("+", 1, sop
    # 1 Check for carries 
    carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0"
    # 2 In parallel, slide carries to their column                                         
    carries = atoi(slide(carry != "<", carry
    # 3 Add in carries.                                                                                  
    return (x + carries % 10
add(("683+345"
683 + 345
1028

完美搞定!

参考资料 & 文内链接:

  • 如果你对这个主题感兴趣想了解更多,请查看论文:Thinking Like Transformers

  • 以及了解更多 RASP 语言

  • 如果你对「形式语言和神经网络」(FLaNN 感兴趣或者有认识感兴趣的人,欢迎邀请他们加入我们的 线上社区!

  • 本篇博文,包含库、Notebook 和博文的内容

  • 本博客文章由 Sasha Rush 和 Gail Weiss 共同编写


译者:innovation64 (李洋

编程笔记 » 了解 Transformers 是如何“思考”的

赞同 (55) or 分享 (0)
游客 发表我的评论   换个身份
取消评论

表情
(0)个小伙伴在吐槽