预测函数部分
# Prediction / evaluation on the test split.
# NOTE(review): `generator`, `yields` and `load_model` are defined or imported
# elsewhere in the file; `yields['test']` is presumably a dict of ID -> yield
# value (it is indexed that way below) — confirm against where it is built.
IDS = list(yields['test'].keys())
# `generator` samples IDs with replacement, so this is ONE random batch of 16
# test samples rather than the whole test set; the metrics below reflect only it.
test_gen = generator(IDS, yields['test'], 16)
X_test, y_test = next(test_gen)

# Use the builtin sum() instead of a module-level variable named `sum`, which
# shadowed the builtin for the rest of the module.
avg = sum(yields['test'][i] for i in IDS) / len(IDS)
print("Average Test Yield is ", avg)

model = load_model('CNN_LSTM_AVG_1000')
print(len(X_test), len(y_test))
# NOTE(review): the pasted traceback ends at "NotImplementedError: in user
# code:" and the actual message is cut off. A commonly reported cause for
# Keras LSTM models on TF 2.x is an incompatible NumPy release (>= 1.20),
# which fails with "Cannot convert a symbolic Tensor ... to a numpy array".
# Confirm against the full error message before changing package versions.
a, b = model.evaluate(X_test, y_test, batch_size=16)
y_pred = model.predict(X_test, batch_size=16, verbose=1)
generator 函数（负责加载并处理数据）：
def generator(IDs, yields, batch_size, cutoff=None):
    """Yield an endless stream of random (features, yields) batches from .npy files.

    Each batch draws ``batch_size`` IDs uniformly at random, with replacement,
    from ``IDs`` and loads ``<data_dir><ID>.npy`` for each one. A file that
    contains any NaN is reported with ``print('no', ID)`` and then excluded
    from further sampling. Its slot is refilled with another ID instead of
    being left holding zeros or stale data from an earlier batch.

    Args:
        IDs: Sequence of sample IDs, used as the .npy file stems.
        yields: Mapping from each ID to its target value.
        batch_size: Number of samples per batch.
        cutoff: If not None, read from ``Data/PROCESSED_III/`` and keep only
            the first ``cutoff`` entries along axis 0 of each array. Batch
            features then have shape ``(batch_size, cutoff, 1, 256, 10)``.
            If None, read from ``Data/data_1/``, and batch features have shape
            ``(batch_size, 14, 1, 32, 4)``.

    Yields:
        ``(batch_features, batch_yields)``, freshly allocated for every batch.

    Raises:
        ValueError: If ``IDs`` is empty or every sampled file contains NaN.
    """
    import numpy as np
    import random

    if cutoff is not None:
        data_dir = 'Data/PROCESSED_III/'
        sample_shape = (cutoff, 1, 256, 10)
    else:
        data_dir = 'Data/data_1/'
        sample_shape = (14, 1, 32, 4)

    # Copy so the caller's sequence is never mutated when bad IDs are dropped.
    candidates = list(IDs)
    while True:
        # New arrays per batch: a batch the caller still holds must not be
        # overwritten in place by the next call to next().
        batch_features = np.zeros((batch_size,) + sample_shape)
        batch_yields = np.zeros(batch_size)
        i = 0
        while i < batch_size:
            if not candidates:
                raise ValueError(
                    'generator: no usable samples (IDs is empty or every '
                    'file contains NaN)')
            ID = random.choice(candidates)
            # Load each file once; the NaN check covers the whole file,
            # including rows beyond `cutoff`.
            data = np.load(data_dir + ID + '.npy')
            if np.isnan(data).any():
                print('no', ID)
                # Drop every occurrence so this file is never reloaded.
                candidates = [c for c in candidates if c != ID]
                continue
            # data[:None] is the whole array, so one slice serves both modes.
            batch_features[i] = data[:cutoff]
            batch_yields[i] = yields[ID]
            i += 1
        yield batch_features, batch_yields
报错
Traceback (most recent call last):
File "F:/Crop-Yield-Prediction-Using-CNN-LSTM/Crop-Yield-Prediction-Using-CNN-LSTM--Temp/Code/Train_HB.py", line 174, in <module>
a, b = model.evaluate(X_test, y_test, batch_size=16)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper
return method(self, *args, **kwargs)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1081, in evaluate
tmp_logs = test_function(iterator)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 627, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 506, in _initialize
*args, **kwds))
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2446, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\function.py", line 2667, in _create_graph_function
capture_by_value=self._capture_by_value),
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\framework\func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\eager\def_function.py", line 441, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "D:\anaconda3\envs\DL\lib\site-packages\tensorflow\python\framework\func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
NotImplementedError: in user code:
求解答：请问如何解决上面的 NotImplementedError？
相似问题