目录

简介

Let's now put the few concepts we have covered so far — Tensor, GradientTape, and Variable — to use by building and training a simple model. This typically involves a few steps:

  • Define the model.
  • Define a loss function.
  • Obtain training data.
  • Run through the training data and use an "optimizer" to adjust the variables to fit the data.

代码

# coding: utf8

# ref: https://www.tensorflow.org/tutorials/eager/custom_training

import tensorflow as tf
import matplotlib.pyplot as plt
import random

tfe = tf.contrib.eager

tf.enable_eager_execution()


class Model(object):
    """Linear model ``y = W * x + b`` with scalar trainable parameters.

    Args:
        W: optional initial weight. When omitted, a random integer drawn from
            ``random.randint(-5, 5)`` (inclusive) is used, converted to float.
        b: optional initial bias, same default rule as ``W``.

    Passing explicit values makes a run reproducible; ``Model()`` behaves
    exactly as before.
    """

    def __init__(self, W=None, b=None):
        # Draw W before b so the sequence of `random` calls matches the
        # original implementation when neither value is supplied.
        if W is None:
            W = random.randint(-5, 5)
        if b is None:
            b = random.randint(-5, 5)
        # float() makes the variables floating point, so gradients and the
        # `assign_sub` updates in training are not integer-typed.
        self.W = tfe.Variable(float(W))
        self.b = tfe.Variable(float(b))

    def __call__(self, x):
        """Return the prediction ``W * x + b`` for input ``x``."""
        return self.W * x + self.b


def loss(predicted_y, desired_y):
    """Mean squared error between predictions and targets.

    Args:
        predicted_y: model predictions.
        desired_y: target values, combined elementwise with ``predicted_y``.

    Returns:
        A scalar tensor: the mean of the squared differences.
    """
    residual = predicted_y - desired_y
    squared = tf.square(residual)
    return tf.reduce_mean(squared)


def train(model, inputs, outputs, learning_rate):
    """Apply one step of plain gradient descent to ``model``'s W and b.

    The MSE loss of ``model(inputs)`` against ``outputs`` is recorded on a
    GradientTape. Each parameter is then updated in place with
    ``param -= learning_rate * d(loss)/d(param)``.
    """
    params = [model.W, model.b]
    with tf.GradientTape() as tape:
        step_loss = loss(model(inputs), outputs)
    grads = tape.gradient(step_loss, params)
    # Updates happen in the same order as the parameter list: W first, then b.
    for param, grad in zip(params, grads):
        param.assign_sub(learning_rate * grad)


# Ground-truth parameters of the synthetic relationship y = TRUE_W * x + TRUE_b.
TRUE_W = 3.0
TRUE_b = 2.0
NUM_EXAMPLES = 1000

# Synthesize training data. Random noise is added to the targets, so even
# the true parameters do not reach zero loss.
inputs = tf.random_normal(shape=[NUM_EXAMPLES])
noise = tf.random_normal(shape=[NUM_EXAMPLES])
outputs = inputs * TRUE_W + TRUE_b + noise

model = Model()

# Uncomment to compare the data (blue) with the untrained model's
# predictions (red).
# plt.scatter(inputs, outputs, c='b')
# plt.scatter(inputs, model(inputs), c='r')
# plt.show()


# Report the starting loss and parameters. `print('...'),` was a Python 2
# idiom for suppressing the newline. In Python 3 it only builds a discarded
# tuple, so label and value are printed in a single call instead.
print('Current loss:', loss(model(inputs), outputs).numpy())
print(model.W.numpy(), model.b.numpy())


# Full-batch gradient descent, recording the parameters each epoch for plotting.
Ws, bs = [], []
epochs = range(300)
for epoch in epochs:
    # Snapshot the parameters and loss *before* this epoch's update, so that
    # Ws[-1], bs[-1] and current_loss describe the same model state.
    Ws.append(model.W.numpy())
    bs.append(model.b.numpy())
    current_loss = loss(model(inputs), outputs)

    train(model, inputs, outputs, learning_rate=0.01)
    print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %
          (epoch, Ws[-1], bs[-1], current_loss.numpy()))

# Plot the learned parameters (solid) against their true values (dashed).
plt.plot(epochs, Ws, 'r',
         epochs, bs, 'b')
# Use `epochs` as the x values here too, so the reference lines stay aligned
# with the curves even if `epochs` is changed to not start at 0.
plt.plot(epochs, [TRUE_W] * len(epochs), 'r--',
         epochs, [TRUE_b] * len(epochs), 'b--')
plt.legend(['W', 'b', 'true W', 'true b'])
plt.show()

效果

...
Epoch 284: W=3.01 b=1.96, loss=0.91785
Epoch 285: W=3.01 b=1.96, loss=0.91784
Epoch 286: W=3.01 b=1.96, loss=0.91784
Epoch 287: W=3.01 b=1.96, loss=0.91783
Epoch 288: W=3.01 b=1.96, loss=0.91783
Epoch 289: W=3.01 b=1.96, loss=0.91782
Epoch 290: W=3.01 b=1.96, loss=0.91782
Epoch 291: W=3.01 b=1.96, loss=0.91782
Epoch 292: W=3.01 b=1.96, loss=0.91781
Epoch 293: W=3.01 b=1.96, loss=0.91781
Epoch 294: W=3.01 b=1.96, loss=0.91781
Epoch 295: W=3.01 b=1.96, loss=0.91780
Epoch 296: W=3.01 b=1.96, loss=0.91780
Epoch 297: W=3.01 b=1.96, loss=0.91779
Epoch 298: W=3.01 b=1.96, loss=0.91779
Epoch 299: W=3.01 b=1.97, loss=0.91779

avatar