学习根据 sin 预测 cos 的函数曲线，序列长度为 20，具体见下方代码。
# coding=utf-8
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import matplotlib.pyplot as plt
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, Dense
from keras.optimizers import Adam
# Data layout: each batch holds BATCH_SIZE sequences of TIME_STEPS points,
# with INPUT_DIM input features and OUTPUT_SIZE outputs per point.
BATCH_SIZE = 50
TIME_STEPS = 20
INPUT_DIM = 1
OUTPUT_SIZE = 1
# Start position (in integer degrees) of the next batch; get_batch() reads
# and advances it through a `global` statement.
BATCH_START = 0
# Model hyper-parameters: LSTM width and Adam learning rate.
UNITS = 20
LR = 8e-3
def get_batch():
    """Return the next training batch as ``[sin inputs, cos targets, xs]``.

    Builds a ``(BATCH_SIZE, TIME_STEPS)`` grid of consecutive integer degrees
    starting at the module-level ``BATCH_START`` and converts it to radians.
    ``sin`` of the grid is the input sequence and ``cos`` is the target.
    ``BATCH_START`` is then advanced by ``TIME_STEPS``, so row ``i`` of the
    next batch continues exactly where row ``i`` of this batch stopped. This
    matches a ``stateful=True`` LSTM, which carries the final state of
    sample ``i`` into sample ``i`` of the next batch.

    Returns:
        A list ``[seq, res, xs]``:
        ``seq`` has shape ``(BATCH_SIZE, TIME_STEPS, 1)`` and holds the sin
        values; ``res`` has the same shape and holds the cos values; ``xs``
        has shape ``(BATCH_SIZE, TIME_STEPS)`` and holds the radian positions
        (used for plotting).

    Side effects:
        Increments the global ``BATCH_START`` by ``TIME_STEPS``.
    """
    global BATCH_START
    # Consecutive integer degrees, one row per sample, converted to radians.
    xs = (
        np.arange(BATCH_START, BATCH_START + BATCH_SIZE * TIME_STEPS)
        .reshape(BATCH_SIZE, TIME_STEPS)
        * np.pi
        / 180
    )
    seq = np.sin(xs)
    res = np.cos(xs)
    # Advance by one row length (not a whole batch) so each row keeps
    # continuing its own sequence on the next call.
    BATCH_START += TIME_STEPS
    # Add a trailing feature axis: (batch, time) -> (batch, time, 1).
    return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]
# Build a stateful LSTM that maps each step of a sin sequence to a cos value.
model = Sequential()
model.add(LSTM(
    # Width passed positionally: this parameter is ``output_dim`` in Keras 1
    # but ``units`` in Keras 2 / tf.keras, where ``output_dim=`` is
    # deprecated or rejected. Positional use works on all of them.
    UNITS,
    # (batch, time steps, features); TIME_STEPS is the input length and
    # INPUT_DIM the per-step feature count. A stateful RNN needs the fixed
    # batch size given here.
    batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_DIM),
    # return_sequences=True:  (BATCH_SIZE, TIME_STEPS, INPUT_DIM) -> (BATCH_SIZE, TIME_STEPS, UNITS)
    # return_sequences=False: (BATCH_SIZE, TIME_STEPS, INPUT_DIM) -> (BATCH_SIZE, UNITS)
    # True emits an output at every time step; False only at the last step.
    return_sequences=True,
    # The final state of sample i in one batch becomes the initial state of
    # sample i in the next batch.
    stateful=True,
))
# Output layer: apply the same Dense(OUTPUT_SIZE) projection at every step.
model.add(TimeDistributed(Dense(OUTPUT_SIZE)))
model.compile(optimizer=Adam(LR), loss='mse')
# Training loop.
# NOTE: Keras updates the states of a stateful=True layer on predict() just as
# it does on train_on_batch(). Without the snapshot/restore below, every batch
# would pass through the LSTM twice: the plotted prediction would start from
# the wrong state, and the next training batch would inherit a corrupted one.
from keras import backend as K


def _get_rnn_states(layer):
    """Return copies of the current values of ``layer``'s state variables."""
    return [K.get_value(state).copy() for state in layer.states]


def _set_rnn_states(layer, values):
    """Overwrite ``layer``'s state variables with ``values`` (same order)."""
    for state, value in zip(layer.states, values):
        K.set_value(state, value)


lstm_layer = model.layers[0]  # the stateful LSTM is the model's first layer
for step in range(200):
    x_batch, y_batch, xs = get_batch()
    # State carried in from the previous batch, i.e. the start of this one.
    state_before = _get_rnn_states(lstm_layer)
    cost = model.train_on_batch(x_batch, y_batch)
    # State after training on this batch: what the next batch must start from.
    state_after = _get_rnn_states(lstm_layer)
    # Predict this batch from the same starting state training used, then put
    # back the post-training state for the next iteration.
    _set_rnn_states(lstm_layer, state_before)
    predict = model.predict(x_batch, BATCH_SIZE)
    _set_rnn_states(lstm_layer, state_after)
    # Red: true cos for sample 0. Blue dashed: the prediction for sample 0,
    # which is the first TIME_STEPS entries of the flattened output.
    plt.plot(xs[0, :], y_batch[0].flatten(), 'r',
             xs[0, :], predict.flatten()[:TIME_STEPS], 'b--')
    plt.ylim(-1.5, 1.5)
    plt.draw()
    plt.pause(0.2)  # updates the figure and runs the GUI event loop briefly
    if step % 10 == 0:
        print('cost:', cost)
理论和代码-RNN-代码
坚持原创技术分享,您的支持将鼓励我继续创作!