对于 TensorFlow 1.x 风格的 Session 接口(tf.compat.v1),在代码开头加入下面的代码,即可让 TensorFlow 按需分配 GPU 显存,而不是在创建会话时一次性占满几乎全部显存:
# TensorFlow 1.x-style (tf.compat.v1) session setup that makes TensorFlow
# grow its GPU memory allocation on demand.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
# allow_growth=True: start with a small GPU allocation and grow it as needed,
# instead of reserving nearly all GPU memory when the session is created.
config.gpu_options.allow_growth = True
# InteractiveSession installs itself as the default session on construction,
# so later graph-mode ops run under this GPU configuration.
session = InteractiveSession(config=config)
如果使用 Keras,可以用下面的代码为 Keras 后端设置一个自定义配置的 Session,把本进程可占用的 GPU 显存限制在总量的约 30%:
# Give the Keras backend a TF1-style session whose GPU memory use is capped.
import tensorflow as tf
import numpy as np  # NOTE(review): unused in this snippet; kept as in the original.
import keras  # NOTE(review): the session below is set on tf.compat.v1.keras; confirm it also applies to this standalone `keras` for your installed version.

# allow_soft_placement=True: let TensorFlow place an op on another device
# (e.g. the CPU) when the requested device has no kernel for it.
config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
# Allow this process to allocate at most ~30% of the GPU memory.
config.gpu_options.per_process_gpu_memory_fraction = 0.3
# Install a session built with this config as the Keras backend session.
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
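如果使用 TensorFlow 2.x(不再使用 Session),则用下面的代码为每块 GPU 开启显存按需增长(memory growth)。注意:这段代码必须在任何 GPU 被初始化之前执行(即在构建模型或运行任何运算之前),否则 TensorFlow 会抛出 RuntimeError: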
# TensorFlow 2.x: enable on-demand GPU memory growth for every visible GPU.
import tensorflow as tf

# Must run before any GPU has been initialized (before building a model or
# running any op); otherwise set_memory_growth raises RuntimeError.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    # Allocate GPU memory as needed instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(gpu, True)