简介

  • 在tensorflow v1.12中,新定义了一个修饰符函数tf.custom_gradients,用于封装自定义的函数-导数对。

示例

参考:stackoverflow

@tf.custom_gradient
def loop1(x,a):
    def grad(dy):
        return dy*3,dy*2
    n = tf.multiply(x,a)
    return n,grad