import tensorflow as tf
import numpy as np
train_data = tf.constant([
[[0, 1, 2, 3, 4],
[10, 11, 12, 13, 14],
[5, 6, 7, 8, 9]],
[[10, 11, 12, 13, 14],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[10, 11, 12, 13, 14],
[25, 26, 27, 28,