
def main():
parser = set_parser()
options = parser.parse_args()
params = update_params_based_on_args(options)
selected_model = params["model"]["model_name"]
if selected_model == "2D_UNET_base":
model = UNetModule
elif selected_model == "SWIN":
model = SWINModule
train(params, options.gpus, options.mode, options.checkpoint, model)set_parser()是一个函数,用于设置和返回一个argparse.ArgumentParser对象
parser.parse_args()方法来解析命令行参数并将结果存储在options变量中
def update_params_based_on_args(options):
config_p = os.path.join("configurations", options.config_path)
params = load_config(config_p)
if options.name != "":
print(params["experiment"]["name"])
params["experiment"]["name"] = options.name
if options.epochs is not None:
params["train"]["max_epochs"] = options.epochs
if options.batch_size is not None:
params["train"]["batch_size"] = options.batch_size
if options.num_workers is not None:
params["train"]["n_workers"] = options.num_workers
if options.input_path != "":
params["dataset"]["data_root"] = options.input_path
if options.output_path != "":
params["experiment"]["experiment_folder"] = options.output_path
if options.region_to_predict != "":
params["predict"]["region_to_predict"] = options.region_to_predict
if options.year_to_predict != "":
params["predict"]["year_to_predict"] = options.year_to_predict
if options.submission_out_dir != "":
params["predict"]["submission_out_dir"] = options.submission_out_dir
return params具有强度输出和概率输出的模型的基本模块。需要验证和预测实现的抽象类。
BaseModule的类,它继承自LightningModule和ABC。
因为继承自LightningModule,要重写training_step、validation_step、predict_step、configure_optimizers方法,详见后续。
ABC(Abstract Base Class)是一个用于定义抽象基类的元类。抽象基类是不能被实例化的类,它主要用于定义接口和共享方法的规范。通过继承抽象基类,子类需要实现抽象基类中定义的抽象方法,以满足基类的接口规范。抽象基类可以提供一种约束,确保子类的一致性和可替换性。
if self.probabilistic:
# Store bucket means (but not as model parameter) as the channel dimension of the data
self.register_buffer(
"bucket_means",
torch.tensor(self.buckets.means).view(1, -1, 1, 1, 1),
)
self.bucket_means: torch.Tensor如果损失函数是概率型的(probabilistic=True),则代码会使用self.register_buffer方法注册一个缓冲区(buffer)bucket_means,用于存储损失函数的桶均值。这里使用torch.tensor将桶均值转换成张量,并通过view方法对其进行形状变换,以便后续使用。需要注意的是,注册的缓冲区不会作为模型的参数进行优化。
if model_params["upsample"] == "bilinear":
self.upsample = BilinearUpsample(42, 252, self.forecast_length)
elif model_params["upsample"] == "nearest":
self.upsample = NearestUpsample(42, 252, self.forecast_length)
elif model_params["upsample"] == "ninasr":
self.upsample = NinaSRUpsample(
42, 252, self.forecast_length, self.num_classes
)
elif model_params["upsample"] == "edsr":
self.upsample = EDSRUpsample(
42, 252, self.forecast_length, self.num_classes
)
else:
self.upsample = None根据model_params["upsample"]的值选择相应的上采样方法对象赋值给self.upsample。根据代码片段提供的信息,上采样方法可以是BilinearUpsample、NearestUpsample、NinaSRUpsample或EDSRUpsample。如果model_params["upsample"]的值不在这些选项中,self.upsample将被设置为None。
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()abstractmethod是一个装饰器,用于定义抽象方法。抽象方法是在抽象基类中声明但没有实现的方法,它只有方法的声明部分,没有具体的方法体。抽象方法必须在子类中被重写实现,否则子类也会成为抽象类。通过使用abstractmethod装饰器,可以明确地表示某个方法是抽象方法。
在这段代码中,forward方法被定义为抽象方法,即没有具体的实现。抽象方法使用abstractmethod装饰器进行修饰,表示它是一个需要在子类中被重写实现的方法。子类必须提供forward方法的具体实现,以满足抽象基类的接口规范。
def augment_batch(self, batch):
"""Apply augmentation on training batches (flips and 90-degrees rotation)"""
# TODO - Change to data loader
if not self.transform:
return batch
input, label, metadata = batch
angle = random.choice([-90, 0, 90, 180])
transformations = [
v2.RandomHorizontalFlip(),
v2.RandomVerticalFlip(),
v2.RandomRotation([angle, angle]),
]
t = random.choice(transformations)
input = t(input).contiguous()
label = t(label).contiguous()
# Transform masks
metadata["input"]["mask"] = t(metadata["input"]["mask"])
metadata["target"]["mask"] = t(metadata["target"]["mask"])
# Transform static data if any
if self.static_data:
metadata["input"]["topo"] = t(metadata["input"]["topo"])
metadata["target"]["topo"] = t(metadata["target"]["topo"])
metadata["input"]["lat-long"] = t(metadata["input"]["lat-long"])
metadata["target"]["lat-long"] = t(metadata["target"]["lat-long"])
return input, label, metadataaugment_batch方法接受一个batch参数,表示训练批次数据。该方法的作用是对训练批次数据进行增强操作,包括翻转和旋转。增强操作可以提高模型的鲁棒性和泛化能力,使其能够更好地适应不同的输入样本。
在当前的实现中,首先判断是否需要进行数据增强操作,如果self.transform为False,则直接返回原始的批次数据。否则,从批次数据中获取输入、标签和元数据。然后,随机选择一个角度(-90度、0度、90度或180度),并定义一些变换操作,包括随机水平翻转、随机垂直翻转和随机旋转。接下来,从变换操作中随机选择一个变换t,并将其应用于输入、标签和元数据的对应部分。其中,输入和标签通过调用变换对象的__call__方法进行转换,并使用contiguous方法保证数据的连续性。对于元数据中的掩码(mask)数据和静态数据(如果有的话),也需要进行相应的变换操作。最后,返回经过增强操作后的输入、标签和元数据。
def add_static(self, input, metadata):
lat_long = (
metadata["input"]["lat-long"]
.unsqueeze(2)
.repeat(1, 1, self.history_length, 1, 1)
)
topo = (
metadata["input"]["topo"]
.unsqueeze(2)
.repeat(1, 1, self.history_length, 1, 1)
)
input = torch.cat([input, lat_long, topo], dim=1)
return inputadd_static方法接受两个参数,input表示输入数据,metadata表示元数据。该方法的作用是将静态数据(lat_long和topo)添加到输入数据中。
首先,代码从metadata中获取了lat_long和topo数据。这些数据可能是二维张量,表示地理坐标和地形信息。然后,通过使用unsqueeze方法在适当的维度上添加一个维度,以便进行重复复制。使用repeat方法将lat_long和topo在相应的维度上进行重复,以匹配输入数据的形状。接下来,使用torch.cat方法将输入数据、lat_long和topo在维度1上进行连接,将它们合并成一个更大的输入张量。最后,返回合并后的输入数据。
def training_step(self, batch):
batch = self.augment_batch(batch)
input, label, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
mask = metadata["target"]["mask"]
loss = self.loss_fn(prediction, label, mask)
self.log("train/loss", loss, sync_dist=True)
return losstraining_step方法接受一个batch参数,表示训练批次数据。该方法的作用是执行一次训练步骤,包括数据增强、添加静态数据、输入转换、模型前向传播、上采样、计算损失和记录训练损失。
首先,代码调用augment_batch方法对批次数据进行增强操作,得到增强后的批次数据。然后,从增强后的批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。最后,使用self.log方法记录训练损失,并返回损失值。
def validation_step(self, batch, batch_idx) -> ValidationOutput:
input, label, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
mask = metadata["target"]["mask"]
loss = self.loss_fn(prediction, label, mask)
self.log("val/loss", loss, sync_dist=True)
if self.probabilistic:
# If no softmax, apply as it is required for the metrics (i.e. CRPS)
if self.activation == "none":
prediction = nn.functional.softmax(prediction, dim=1)
probabilities = prediction
intensity = self.integrate(prediction)
else:
probabilities = None
intensity = prediction
return ValidationOutput(intensity=intensity, probabilities=probabilities)validation_step方法接受两个参数,batch表示验证批次数据,batch_idx表示批次索引。该方法的作用是执行一次验证步骤,包括添加静态数据、输入转换、模型前向传播、上采样、计算损失、记录验证损失和返回验证输出。
首先,代码从验证批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。接着,使用self.log方法记录验证损失,并传入"val/loss"作为日志名称,loss作为损失值,并设置sync_dist=True以确保在分布式训练中同步日志。如果模型的损失函数是概率型的(self.probabilistic=True),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,因为一些指标(如 CRPS)需要概率分布的预测结果。然后,将预测结果作为概率分布probabilities,并将预测结果进行积分得到intensity。最后,返回一个ValidationOutput对象,包含intensity和probabilities。
def predict_step(self, batch, batch_idx=None) -> torch.Tensor:
input, _, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
if self.probabilistic:
# If no softmax, apply as it to sum 1
if self.activation == "none":
prediction = nn.functional.softmax(prediction, dim=1)
probabilities = prediction
intensity = self.integrate(prediction)
else:
probabilities = None
intensity = prediction
intensity = intensity[:, :, : self.forecast_length, :, :]
return intensity首先,代码从预测批次数据中获取输入数据和元数据,忽略了标签数据(_)。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample不为None),则对预测结果进行上采样操作。如果模型的损失函数是概率型的(self.probabilistic=True),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,以确保预测结果的和为1。然后,将预测结果作为概率分布probabilities,并将预测结果进行积分得到intensity。最后,根据预测长度截取intensity中的相应部分,并返回截取后的intensity作为预测结果。
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
return optimizer使用了torch.optim.AdamW优化器类来创建一个AdamW优化器对象。AdamW是Adam优化器的一种变体,它在优化过程中引入了权重衰减(weight decay)的正则化项,有助于控制模型的复杂度并提高泛化能力。在创建优化器对象时,传入了两个参数。self.parameters()表示要优化的模型参数,即模型中所有需要进行梯度更新的参数。lr=self.lr和weight_decay=self.weight_decay分别指定了学习率和权重衰减的数值,这些数值是在模型初始化时从参数中获取的。最后,将创建的优化器对象返回。
交叉熵和均方误差计算,对应概率输出和强度输出。
callbacks文件夹应该放回调代码就可以了,不知道为什么把metrics代码也放这里。

用于在PyTorch Lightning框架中记录和计算各种指标(metrics)的值
def __init__(self, num_leadtimes, probabilistic, buckets, logging):
super().__init__()
self.num_leadtimes = num_leadtimes
self.probabilistic = probabilistic
if buckets != "none":
self.buckets = BUCKET_CONSTANTS[buckets]
else:
self.buckets = None
self.logging = logging
self.thresholds = [0.2, 1, 5, 10, 15]接收参数num_leadtimes(leading time steps)、probabilistic(是否概率性指标)、buckets(用于概率性指标的桶大小)、logging(指标记录的方式)。
from dataclasses import dataclass
from typing import List
@dataclass
class Bucket:
idx: int
mean: float
max: float
weight: float
@dataclass
class BucketConstants:
buckets: List[Bucket]
means: List[float]
weights: List[float]
boundaries: List[float]
ranges: List[float]
num_buckets: int
# Custom buckets used for classification when using mm/h
_buckets_mmh = [
Bucket(idx=0, mean=0, max=0.08, weight=0.5107),
Bucket(idx=1, mean=0.12, max=0.16, weight=0.6014),
Bucket(idx=2, mean=0.2, max=0.25, weight=0.627),
Bucket(idx=3, mean=0.32, max=0.4, weight=0.6295),
Bucket(idx=4, mean=0.51, max=0.63, weight=0.631),
Bucket(idx=5, mean=0.81, max=1, weight=0.6359),
Bucket(idx=6, mean=1.3, max=1.6, weight=0.6472),
Bucket(idx=7, mean=2.0, max=2.5, weight=0.6667),
Bucket(idx=8, mean=3.25, max=4, weight=0.6901),
Bucket(idx=9, mean=5.15, max=6.3, weight=0.7298),
Bucket(idx=10, mean=8.1, max=10, weight=0.7823),
Bucket(idx=11, mean=13, max=16, weight=0.8428),
Bucket(idx=12, mean=20.5, max=25, weight=0.9084),
Bucket(idx=13, mean=32.5, max=40, weight=0.9617),
Bucket(
idx=14, mean=45, max=128, weight=1.0
), # Max is 128 as defined by preprocessing
]
def getBucketObject(buckets_list):
return BucketConstants(
buckets=buckets_list,
means=[b.mean for b in buckets_list],
weights=[b.weight for b in buckets_list],
boundaries=[b.max for b in buckets_list[:-1]],
ranges=[
buckets_list[i].max - buckets_list[i - 1].max
if i > 0
else buckets_list[i].max
for i in range(len(buckets_list))
],
num_buckets=len(buckets_list),
)
BUCKET_CONSTANTS = {
"mmh": getBucketObject(_buckets_mmh),
"test": getBucketObject(_buckets_test),
"w4c23_1": getBucketObject(_buckets_w4c23_1),
"w4c23_2": getBucketObject(_buckets_w4c23_2),
}创建和管理不同的桶(Bucket)对象,并将其存储在BUCKET_CONSTANTS字典中。通过调用getBucketObject函数,可以根据桶列表获取相应的BucketConstants对象。这样做的目的是为了方便地创建和使用不同的桶,并将其关联到特定的名称,以供其他代码使用。
dataclasses 模块提供了一个装饰器 @dataclass,用于方便地创建和操作数据类(data class),它自动为类的属性生成相应的方法(如构造函数、属性访问方法、比较方法等),使得创建和操作数据对象更加简洁和方便。
from dataclasses import dataclass
@dataclass
class Person:
name: str
age: int
occupation: str # # Code for checking if a metric can be optimized
# check_forward_full_state_property(
# metrics.MeanSquaredError,
# input_args={
# "prediction": torch.Tensor([0.5, 2.5]),
# "label": torch.Tensor([1.0, 2.0]),
# "mask": torch.zeros([2], dtype=bool),
# },
# )被注释掉的代码是用于检查一个指标是否可以进行优化的示例代码。它使用torchmetrics库中的check_forward_full_state_property函数来检查均方误差(MeanSquaredError)指标是否可以进行优化。函数的输入参数为一个字典,包含了预测值(prediction)、标签值(label)和掩码(mask)。通过检查指标的前向计算是否可以成功执行,可以确保指标的正确性和可用性。
def _threshold_str(self, threshold):
"""Remove .0 and change . by -"""
return f"{threshold:g}".replace(".", "-")该段代码定义了一个名为"_threshold_str"的私有方法,用于处理阈值(threshold)的字符串表示。
该方法接受一个阈值参数,将其转换为字符串表示。转换过程包括以下步骤:
"g" 是格式化字符串中的一种格式化选项,用于表示通用格式。它会根据阈值的类型自动选择合适的表示方式,并去除多余的零和小数点。具体来说,对于整数类型的阈值,它会显示为普通整数的形式,如 5、10、100 等。而对于浮点数类型的阈值,它会显示为一般的浮点数格式,如 0.5、1.0、2.5 等。在这个过程中,多余的零和小数点会被去除。
最后,该方法返回处理后的字符串表示形式。
该方法的作用是将阈值转换为特定的字符串表示形式,可能是为了后续的指标命名或其他需要使用特定格式的字符串的目的。由于该方法是私有方法(以单个下划线开头),它在类外部不可直接访问,只能在类内部被调用。
def setup(self, trainer, pl_module, stage):
# Setup scalar metrics
scalar_metrics = {}
scalar_metrics["mse"] = metrics.MeanSquaredError()
scalar_metrics["mae"] = metrics.MeanAverageError()
for threshold in self.thresholds:
csi = metrics.CriticalSuccessIndex(threshold=threshold)
scalar_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
scalar_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
thresholds=self.thresholds
)
if self.probabilistic:
scalar_metrics["crps"] = metrics.ContinuousRankedProbabilityScore(
self.buckets
)
# Create metric collections and put metrics on module to automatically place on correct device
val_scalar_metrics = torchmetrics.MetricCollection(scalar_metrics)
pl_module.val_metrics = val_scalar_metrics.clone(prefix="val/")
# Lead time metrics
lead_time_metrics = {}
lead_time_metrics[f"mse"] = metrics.MeanSquaredError(
num_leadtimes=self.num_leadtimes
)
for threshold in self.thresholds:
csi = metrics.CriticalSuccessIndex(
threshold=threshold, num_leadtimes=self.num_leadtimes
)
lead_time_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
lead_time_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
thresholds=self.thresholds, num_leadtimes=self.num_leadtimes
)
pl_module.lead_time_metrics = torchmetrics.MetricCollection(lead_time_metrics)在setup方法中,主要进行了以下操作:
scalar_metrics,用于存储标量指标。scalar_metrics字典中添加均方误差(MeanSquaredError)和平均绝对误差(MeanAverageError)指标。self.thresholds)循环遍历,为每个阈值创建关键成功指数(CriticalSuccessIndex)指标,并将其添加到scalar_metrics字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"的格式,其中self._threshold_str(threshold)将阈值转换为特定的字符串表示形式。scalar_metrics字典中,其中的阈值使用了阈值列表(self.thresholds)。self.probabilistic为True,则添加连续排名概率评分(ContinuousRankedProbabilityScore)指标到scalar_metrics字典中,其中的桶(buckets)参数使用了self.buckets。MetricCollection)并将指标放入模块(pl_module)中:MetricCollection是torchmetrics的一个方法,接收字典输入,创建指标集合。scalar_metrics字典创建标量指标集合(val_scalar_metrics)。val_scalar_metrics.clone(prefix="val/")创建一个带有前缀的克隆集合,前缀为"val/"。val_metrics属性,用于在验证过程中记录和计算指标。lead_time_metrics,用于存储引导时间指标。lead_time_metrics字典中添加均方误差(MeanSquaredError)指标,其中的引导时间数量使用了self.num_leadtimes。self.thresholds)循环遍历,为每个阈值创建引导时间关键成功指数(CriticalSuccessIndex)指标,并将其添加到lead_time_metrics字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"的格式,其中self._threshold_str(threshold)将阈值转换为特定的字符串表示形式。lead_time_metrics字典中,其中的阈值使用了阈值列表(self.thresholds)和引导时间数量(self.num_leadtimes)。lead_time_metrics)赋值给模块的lead_time_metrics属性,用于在验证过程中记录和计算引导时间指标。总的来说,setup方法主要用于设置回调函数中的指标,包括标量指标和引导时间指标。它创建了相应的指标对象,并将它们放入模块中,以便在训练过程中使用和记录。
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
"""Called after each validation batch with scalar and lead time metrics"""
_, label, metadata = batch
pl_module.val_metrics(outputs, label, metadata["target"]["mask"])
pl_module.lead_time_metrics(outputs, label, metadata["target"]["mask"])batch参数中解包获取到三个值,即_(不使用)、label和metadata。这些值通常代表了模型输出、标签和元数据等。pl_module)的val_metrics指标集合,传递模型输出、标签和目标掩码(metadata["target"]["mask"]),来计算标量指标的值。lead_time_metrics指标集合,传递模型输出、标签和目标掩码,来计算引导时间指标的值。通过调用指标集合的方法,可以将模型的输出、标签和目标掩码传递给指标集合,以便计算相应的指标值。这些指标值将用于后续的记录和评估过程。
def on_validation_epoch_end(self, trainer, pl_module):
# Log validation scalar metrics
pl_module.log_dict(
pl_module.val_metrics, on_step=False, on_epoch=True, sync_dist=True
)
# Compute and log lead time metrics
lead_time_metrics = pl_module.lead_time_metrics.compute()
lead_time_metrics_dict = {}
wandb_data = []
for metric_name, arr in lead_time_metrics.items():
# Add to logging dictionary
for leadtime, value in enumerate(arr):
lead_time_metrics_dict[f"val_time/{metric_name}_{leadtime+1}"] = value
# Save to file (tensorboard)
if self.logging == "tensorboard":
file_path = os.path.join(
pl_module.logger.log_dir, f"val_lead_time_{metric_name}.pt"
)
torch.save(arr.cpu(), file_path)
# Generate table for wandb
elif self.logging == "wandb":
columns = ["metric"] + [f"t_{i+1}" for i in range(len(arr))]
wandb_data.append([metric_name] + arr.tolist())
# Save table in wandb
if self.logging == "wandb":
pl_module.logger.log_table(
key="leadtimes", columns=columns, data=wandb_data
)
# Save lead time metrics over time
pl_module.log_dict(
lead_time_metrics_dict, on_step=False, on_epoch=True, sync_dist=True
)
pl_module.lead_time_metrics.reset()pl_module.val_metrics指标集合,通过调用模块的log_dict方法,将标量指标的值记录到日志中。on_step=False和on_epoch=True,以确保在验证周期结束时记录指标的值。sync_dist=True来同步跨多个设备的指标值。pl_module.lead_time_metrics指标集合的compute方法,计算引导时间指标的值。lead_time_metrics_dict,用于存储引导时间指标的名称和值。wandb_data,用于存储生成表格所需的数据。lead_time_metrics_dict字典中,以便后续的记录和保存。self.logging为"tensorboard",则将引导时间指标的值保存到文件中,文件名为val_lead_time_{metric_name}.pt。self.logging为"wandb",则生成一个表格所需的数据,其中包括指标名称和对应的值。self.logging为"wandb",则将生成的表格数据使用pl_module.logger.log_table方法保存到wandb中,其中的key表示表格的唯一标识,columns表示表格的列名,data表示表格的数据。pl_module的log_dict方法,将引导时间指标的名称和值记录到日志中。on_step=False和on_epoch=True,以确保在验证周期结束时记录指标的值。sync_dist=True来同步跨多个设备的指标值。pl_module.lead_time_metrics指标集合的reset方法,重置引导时间指标的状态,以便在下一个验证周期开始时重新计算。
torchmetrics中的Metric类,重写了full_state_update和higher_is_better两个属性、update和compute两个方法。在类的定义中,full_state_update被设置为False,表示不需要完全状态更新;higher_is_better被设置为True,表示指标的值越高越好。
在PyTorch的Metric类中,通常会定义一些状态变量,用于保存指标计算过程中的中间结果。这些状态变量可以在每次更新指标时被更新。而完全状态更新是指每次更新指标时,都会将所有的状态变量进行更新。然而,并不是所有的指标都需要进行完全状态更新。有些指标的计算只依赖于最近一次更新的状态,而不需要考虑之前的状态。在这种情况下,可以将full_state_update设置为False,以优化计算性能。这次计算的CSI指标跟之前的状态就无关,因此不需要完全状态更新。
在update方法中,接受了三个参数prediction、label和mask,用于更新指标的计算。根据阈值列表和预测结果,将预测结果转换为二进制形式,并根据reduce_time的值进行不同的操作。
在compute方法中,计算了平均关键成功指数(CSI),即真阳性(true positives)除以真阳性和假预测(false guesses)之和的平均值。

thresholds和num_leadtimes的值,选择不同的默认值和设置self.reduce_time的值。num_leadtimes为None或者等于1,表示只有一个时间步,那么默认值default将被设置为一个形状为(len(thresholds),)的全零张量,并且self.reduce_time将被设置为True,表示需要减少时间维度。num_leadtimes大于1,表示有多个时间步,那么默认值default将被设置为一个形状为(len(thresholds), num_leadtimes)的全零张量,并且self.reduce_time将被设置为False,表示不需要减少时间维度。num_leadtimes小于等于0,则会抛出ValueError异常,提示num_leadtimes必须大于0。thresholds参数赋值给self.thresholds属性,以便在后续的计算中使用。self.add_state方法,将名为"true_positives"和"false_guesses"的状态变量添加到指标类中。这两个状态变量的默认值都是通过default.clone()来创建的,同时设置了分布式合并函数dist_reduce_fx为"sum"dist_reduce_fx,Metric类中使用分布式合并函数的目的是支持在分布式计算环境中进行指标的计算和合并,在分布式计算环境中,通常有多个计算节点或进程同时进行计算任务。每个节点或进程都可能独立地计算指标的一部分,并生成局部的状态变量。为了得到整体的指标结果,需要将各个节点或进程上计算得到的状态变量进行合并。更新状态变量。

首先,根据阈值列表self.thresholds,使用enumerate函数遍历阈值列表的索引和值,因为CSI指标的计算在不同thresholds下是不同的。
接下来,将预测结果prediction的强度(intensity)赋给变量pred。
然后,将pred和label转换为二进制形式。将pred中大于等于当前thresholds的元素设置为真(True),其余为假(False)。同样,将label中大于等于当前thresholds的元素设置为真,其余为假。
接着,根据self.reduce_time的值进行不同的操作。
self.reduce_time为True,表示只有一个时间步,那么将根据mask对pred和lab进行掩码操作,即将掩码为真(True)的位置从pred和lab中剔除。self.reduce_time为False,表示有多个时间步,那么通过重新排列张量的维度,将pred和lab的时间维度放到最后的位置,即将形状由"b c t h w"变为"(b c h w) t"。同时,对mask进行相同的重新排列操作,并使用torch.logical_and函数将pred和lab与掩码取反(~m)进行逻辑与操作,以将掩码位置视为真(True)。这样可以保留其他维度的信息并考虑掩码。最后,根据预测结果和标签计算真阳性(true positives)和假预测(false guesses)的总数。使用torch.logical_and函数计算pred和lab的逻辑与,得到同时为真的位置,然后使用sum(dim=0)对每个时间步的结果进行求和,将结果累加到self.true_positives[i]中。使用(pred != lab)进行逻辑不等于操作,得到不一致的位置,然后使用sum(dim=0)对每个时间步的结果进行求和,将结果累加到self.false_guesses[i]中。
通过循环遍历阈值列表和计算真阳性和假预测的总数,update方法更新了指标类中的状态变量。
根据状态变量计算最终指标。

各种分箱策略。
继承自yaml库的SafeLoader类,用于解析YAML文件(里没事各种参数设定)。
用于各种数据处理。使用的情况有:
train.py:
from w4c23.utils.data_utils import get_cuda_memory_usage, tensor_to_submission_file
sampler.py:
from w4c23.utils.data_utils import get_file
w4c_dataloader.py:
from w4c23.utils.data_utils import *数据集中样本的抽样策略,实现重要性采样。
读取并归一化大赛数据。
保存模型参数。
保存定义模型的各种参数组合。
原始数据。
2D U-Net 架构的输出与其输入具有相同的空间维度。这意味着对于大小为 128 x 128 像素的输入序列,通过 U-Net 的前向传播将生成大小为 128 x 128 像素的输出。标签对应于大小为 42 x 42 像素的中心块。因此,为了指导降水临近预报模型,我们采用中央 42 x 42 像素块并上采样到 252 x 252 像素标签。这种裁剪和上采样是在 MetNet 9 中引入的,这是由于输入和标签的空间分辨率不同所致,如第 3 节中所述。
解释为什么要对标签值上采样。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。