taichi.ad

class taichi.ad.FwdMode(loss, param, seed=None, clear_gradients=True)
clear_seed(self)
insert(self, func, mode_original)
recover_kernels(self)
class taichi.ad.Tape(loss=None, clear_gradients=True)
grad(self)
insert(self, func, args)
taichi.ad.clear_all_gradients()

Sets the gradients of all fields to zero.

taichi.ad.grad_for(primal)

Generates a decorator for the customized gradient function of primal.

See grad_replaced() for examples.

Parameters:

primal (Callable) – The primal function, which must be decorated by grad_replaced().

Returns:

The decorator used to decorate customized gradient function.

Return type:

Callable

taichi.ad.grad_replaced(func)

A decorator for a Python function that customizes its gradient within Taichi’s autodiff system, e.g. ti.ad.Tape() and kernel.grad().

This decorator forces Taichi’s autodiff system to use a user-defined gradient function for the decorated function. The customized gradient function must be decorated by grad_for().

Parameters:

func (Callable) – The Python function to be decorated.

Returns:

The decorated function.

Return type:

Callable

Example:

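>>> import taichi as ti
>>> ti.init()
>>> x = ti.field(ti.float32, shape=4, needs_grad=True)
>>> y = ti.field(ti.float32, shape=4, needs_grad=True)
>>> @ti.kernel
... def multiply(a: ti.float32):
...     for I in ti.grouped(x):
...         y[I] = x[I] * a
...
>>> @ti.kernel
... def multiply_grad(a: ti.float32):
...     # Adjoint of y = x * a with respect to x: dy/dx = a
...     for I in ti.grouped(x):
...         x.grad[I] += y.grad[I] * a
...
>>> @ti.ad.grad_replaced
... def foo(a):
...     multiply(a)
...
>>> @ti.ad.grad_for(foo)
... def foo_grad(a):
...     multiply_grad(a)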
taichi.ad.no_grad(func)

A decorator for a Python function that skips gradient calculation within Taichi’s autodiff system, e.g. ti.ad.Tape() and kernel.grad(). It forces the autodiff system to use an empty gradient function for the decorated function, so no gradient propagates back through it.

Parameters:

func (Callable) – The Python function to be decorated.

Returns:

The decorated function.

Return type:

Callable

Example:

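>>> import taichi as ti
>>> ti.init()
>>> x = ti.field(ti.float32, shape=4, needs_grad=True)
>>> y = ti.field(ti.float32, shape=4, needs_grad=True)
>>> @ti.kernel
... def multiply(a: ti.float32):
...     for I in ti.grouped(x):
...         y[I] = x[I] * a
...
>>> @ti.ad.no_grad
... def foo(a):
...     multiply(a)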