元编程
Taichi 为元编程提供了基础架构。 在 Taichi 中,元编程具有很多好处:
- 有利于维度自适应代码的开发,例如即适用于 2 维也适用于3 维情况的物理模拟。
- 将运行耗时移动到编译耗时,以提高运行时的性能。
- 简化 Taichi 标准库的开发.
note
Taichi kernels are lazily instantiated and large amounts of computation can be executed at compile-time. 即使没有模板参数,Taichi 中的每一个内核也都是模板内核。
模版元编程
通过使用 ti.template()
作为参数的类型提示,Taichi field 或者一个 Python 对象可以作为参数被传递到 kernel 当中。 模板编程还可以将内核重用于不同形状的场。
@ti.kernel
def copy_1D(x: ti.template(), y: ti.template()):
for i in x:
y[i] = x[i]
a = ti.field(ti.f32, 4)
b = ti.field(ti.f32, 4)
c = ti.field(ti.f32, 12)
d = ti.field(ti.f32, 12)
# Pass field a and b as arguments of the kernel `copy_1D`:
copy_1D(a, b)
# Reuse the kernel for field c and d:
copy_1D(c, d)
note
If a template parameter is not a Taichi object, it cannot be reassigned inside Taichi kernel.
note
The template parameters are inlined into the generated kernel after compilation.
使用组合索引实现维度自适应的编程
Taichi 提供了 ti.group
语法,支持将循环下标集合成 ti.Vector
。 它使得独立于维度的编程成为可能,即代码能够自适应 于不同维度的场景:
@ti.kernel
def copy_1D(x: ti.template(), y: ti.template()):
for i in x:
y[i] = x[i]
@ti.kernel
def copy_2d(x: ti.template(), y: ti.template()):
for i, j in x:
y[i, j] = x[i, j]
@ti.kernel
def copy_3d(x: ti.template(), y: ti.template()):
for i, j, k in x:
y[i, j, k] = x[i, j, k]
# Kernels listed above can be unified into one kernel using `ti.grouped`:
@ti.kernel
def copy(x: ti.template(), y: ti.template()):
for I in ti.grouped(y):
# I 是一个维度和 y 相同的向量
# 如果 y 是 0 维的,则 I = ti.Vector([]),就相当于`None`被用于 x[I]
# 如果 y 是 1 维的,则I = ti.Vector([i])
# 如果 y 是 2 维的,则 I = ti.Vector([i, j])
# 如果 y 是 3 维的,则 I = ti.Vector([i, j, k])
# ...
x[I] = y[I]
场的元数据
无论在 Taichi 作用域还是在 Python 作用域中,都可以使用 field.dtype
和 field.shape
来访问 field 的数据类型和尺寸这两个属性。
x = ti.field(dtype=ti.f32, shape=(3, 3))
# 在 Python 作用域中打印场的元数据
print("Field dimensionality is ", x.shape)
print("Field data type is ", x.dtype)
# 在 Taichi 作用域中打印场的元数据
@ti.kernel
def print_field_metadata(x: ti.template()):
print("Field dimensionality is ", len(x.shape))
for i in ti.static(range(len(x.shape))):
print("Size along dimension ", i, "is", x.shape[i])
ti.static_print("Field data type is ", x.dtype)
note
For sparse fields, the full domain shape will be returned.
矩阵 & 向量的元数据
对于矩阵,matrix.m
和 matrix.n
分别返回列数和行数。 Taichi 把向量看作只有一列的矩阵,vector.n
表示的是向量的元素个数。
@ti.kernel
def foo():
matrix = ti.Matrix([[1, 2], [3, 4], [5, 6]])
print(matrix.n) # 行数:3
print(matrix.m) # 列数:2
vector = ti.Vector([7, 8, 9])
print(vector.n) # 元素个数:3
print(vector.m) # 对于向量来说恒为1
编译时评估
使用编译时评估可以将部分计算量移到内核实例化时进行。 这有助于编译器实现最优化以减少运行时的计算开销。
静态作用域
ti.static
是一个接收一个参数的函数。 它提示编译器在编译时评估参数。 ti.static
参数的作用域被称为静态作用域。
编译时分支
- 使用
ti.static
对编译时分支展开(对于熟悉 C++17 的人来说,这类似于 if constexpr。):
enable_projection = True
@ti.kernel
def static():
if ti.static(enable_projection): # 没有运行时开销
x[0] = 1
note
One of the two branches of the static if
will be discarded after compilation.
循环展开
- 使用
ti.static
强制进行循环展开:
@ti.kernel
def func():
for i in ti.static(range(4)):
print(i)
# 上述的代码片段相当于:
print(0)
print(1)
print(2)
print(3)
note
Before v1.4.0, indices for accessing Taichi matrices/vectors must be compile-time constants. Therefore, if the indices come from a loop, the loop must be unrolled:
# Here we declare a field containing 3 vectors. 每一个向量包含8个元素。
x = ti.Vector.field(8, ti.f32, shape=3)
@ti.kernel
def reset():
for i in x:
for j in ti.static(range(x.n)):
# The inner loop must be unrolled since j is an index for accessing a vector.
x[i][j] = 0
Starting from v1.4.0, indices for accessing Taichi matrices/vectors can be runtime variables. Therefore, the loop above is no longer required to be unrolled. That said, unrolling it will still help you reduce runtime overhead.
Compile-time recursion of ti.func
编译时递归函数是一个在编译时递归内嵌的函数。 在编译时评估是否满足递归的条件。
你可以参考 编译时分支 和 template 来写编译时递归函数。
例如, sum_from_one_to
是一个编译时递归函数, 用来计算从 1
到 n 的
的数字之和。
@ti.func
def sum_from_one_to(n: ti.template()) -> ti.i32:
ret = 0
if ti.static(n > 0):
ret = n + sum_from_one_to(n - 1)
return ret
@ti.kernel
def sum_from_one_to_ten():
print(sum_from_one_to(10)) # prints 55
WARNING
When the recursion is too deep, it is not recommended to use compile-time recursion because deeper compile-time recursion expands to longer code during compilation, resulting in increased compilation time.