首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >X-SAM:从“分割万物”到“万物皆可分” | AAAI2026

X-SAM:从“分割万物”到“万物皆可分” | AAAI2026

原创
作者头像
AI小怪兽
发布2025-11-17 09:31:11
发布2025-11-17 09:31:11
2120
举报
文章被收录于专栏:毕业设计毕业设计YOLO大作战

📄 论文核心摘要:

我们介绍了X-SAM,一个创新的框架,它统一了多样化的图像分割任务,将分割范式从"分割任何事物"扩展到了"任何分割"。为实现此目标,我们的方法解决了三个关键的技术挑战:(1)任务形式化:将SAM转变为具有跨任务适用性的通用分割架构。(2)模态增强:增强LLMs以具备多模态输入处理能力。(3)统一框架:开发一种有效促进跨不同领域的全面分割应用的连贯方法。

博主简介

AI小怪兽 | 计算机视觉布道者 | 视觉检测领域创新者

深耕计算机视觉与深度学习领域,专注于视觉检测前沿技术的探索与突破。长期致力于YOLO系列算法的结构性创新、性能极限优化与工业级落地实践,旨在打通从学术研究到产业应用的最后一公里。

🚀 核心专长与技术创新

  • YOLO算法结构性创新:于CSDN平台原创发布《YOLOv13魔术师》、《YOLOv12魔术师》等全系列深度专栏。系统性提出并开源了多项原创自研模块,在模型轻量化设计、多维度注意力机制融合、特征金字塔重构等关键方向完成了一系列突破性实践,为行业提供了具备高参考价值的技术路径与完整解决方案。
  • 技术生态建设与知识传播:独立运营 “计算机视觉大作战” 公众号(粉丝1.6万),成功构建高质量的技术交流社群。致力于将复杂算法转化为通俗易懂的解读与可复现的工程代码,显著降低了计算机视觉的技术入门门槛。

🏆 行业影响力与商业实践

  • 荣获腾讯云年度影响力作者创作之星奖项,内容质量与专业性获行业权威平台认证。
  • 全网累计拥有 7万+ 垂直领域技术受众,专栏文章总阅读量突破百万,在目标检测领域形成了广泛的学术与工业影响力。
  • 具备丰富的企业级项目交付经验,曾为工业视觉检测、智慧城市安防等多个关键领域提供定制化的算法模型与解决方案,驱动业务智能化升级。

💡 未来方向与使命

秉持 “让每一行代码都有温度” 的技术理念,未来将持续聚焦于实时检测、语义分割及工业缺陷检测的商业化闭环等核心方向。愿与业界同仁协同创新,共同推动技术边界,以坚实的技术能力赋能实体经济与行业变革。

原理介绍

论文:https://arxiv.org/pdf/2508.04655

代码:https://github.com/wanghao9610/X-SAM

核心代码:

X-SAM/xsam/xsam/model/xsam.py

代码语言:txt
复制



class XSamModel(BaseModel):
    def __init__(
        self,
        llm=None,
        tokenizer=None,
        visual_encoder=None,
        postprocess_fn=None,
        segmentor=None,
        special_tokens=None,
        freeze_llm=False,
        freeze_visual_encoder=False,
        freeze_segmentor_encoder=False,
        freeze_segmentor_connector=False,
        visual_select_layer=-2,
        visual_select_indx=0,  # 1 for clip, 0 for siglip
        seg_select_layers=[8, 16, 24, 32],
        extract_seg_embeds=True,
        s1_pretrained_pth=None,
        s2_pretrained_pth=None,
        projector_depth=2,
        downsample_ratio=0.5,
        llm_lora=None,
        visual_encoder_lora=None,
        segmentor_lora=None,
        connector_type=None,
        connector_hidden_dim=256,
        connector_scale_factor=[4, 2, 1, 0.5],
        sampler_type="naive",
        sampler_input_feat="seg_pixel_values",
        cond_type: Literal["phrase", "cls", "all"] = "phrase",
        use_dual_encoder=False,
        use_vision_sampler=False,
        use_activation_checkpointing=True,
        max_position_embeddings=None,
        llm_loss_weight: float = 1.0,
        seg_loss_weight: float = 1.0,
    ):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        self.freeze_segmentor_encoder = freeze_segmentor_encoder
        self.freeze_segmentor_connector = freeze_segmentor_connector

        assert (
            llm is not None or visual_encoder is not None or segmentor is not None
        ), "llm, visual_encoder, and segmentor cannot be all None"

        if isinstance(llm, dict):
            llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
        self.llm = self._build_from_cfg_or_module(llm)
        self.tokenizer = self._build_from_cfg_or_module(tokenizer)
        self.visual_encoder = self._build_from_cfg_or_module(visual_encoder)
        self.segmentor = self._build_from_cfg_or_module(segmentor)

        if self.llm is not None:
            self.llm.config.use_cache = False
            dispatch_modules(self.llm)

        self.postprocess_fn = postprocess_fn
        if special_tokens is not None:
            self._add_special_tokens(special_tokens)

        if self.visual_encoder is not None:
            self.projector_depth = projector_depth
            visual_projector_config = DynamicProjectorConfig(
                visual_hidden_size=self.visual_encoder.config.hidden_size,
                llm_hidden_size=self.llm.config.hidden_size,
                depth=self.projector_depth,
            )
            self.visual_projector = DynamicProjectorModel(visual_projector_config).to(self.visual_encoder.dtype)

        if self.segmentor is not None:
            if self.llm is not None and self.segmentor.decoder is not None:
                llm_projector_config = DynamicProjectorConfig(
                    visual_hidden_size=self.llm.config.hidden_size,
                    llm_hidden_size=self.segmentor.dec_config.hidden_size,
                    depth=self.projector_depth,
                )
                self.llm_projector = DynamicProjectorModel(llm_projector_config).to(self.llm.dtype)

            if self.segmentor.encoder is not None and use_dual_encoder:
                seg_projector_config = DynamicProjectorConfig(
                    visual_hidden_size=self.segmentor.enc_config.hidden_size,
                    llm_hidden_size=self.llm.config.hidden_size,
                    downsample_ratio=downsample_ratio,
                    depth=self.projector_depth,
                )
                self.seg_projector = DynamicProjectorModel(seg_projector_config).to(self.segmentor.dtype)

            if self.segmentor.pixel_decoder is not None and extract_seg_embeds and connector_type is not None:
                seg_select_layers = seg_select_layers[-self.segmentor.dec_config.num_feature_levels :]
                connector_config = ConnectorConfig(
                    segmentor_encoder_channels=[self.segmentor.enc_config.hidden_size]
                    * self.segmentor.dec_config.num_feature_levels,
                    hidden_channels=connector_hidden_dim,
                    scale_factor=connector_scale_factor[-self.segmentor.dec_config.num_feature_levels :],
                    connector_type=connector_type,
                )
                self.seg_connector = ConnectorModel(connector_config).to(self.segmentor.dtype)

            if self.segmentor.decoder is not None and use_vision_sampler:
                sampler_config = SamplerConfig(
                    sampler_type=sampler_type,
                    num_sample_point=256,
                    input_dim=self.llm.config.hidden_size,
                    output_dim=self.segmentor.dec_config.hidden_size,
                )
                self.vision_sampler = SamplerModel(sampler_config).to(self.segmentor.dtype)

            if self.segmentor.decoder is not None and self.segmentor.open_cls:
                self.bg_embeds = nn.Embedding(1, self.segmentor.dec_config.hidden_size).to(self.segmentor.dtype)

        if self.freeze_llm and self.llm is not None:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder and self.visual_encoder is not None:
            self.visual_encoder.requires_grad_(False)
        if self.freeze_segmentor_encoder and self.segmentor is not None:
            self.segmentor.encoder.requires_grad_(False)
        if self.freeze_segmentor_connector and self.segmentor is not None:
            self.seg_connector.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if self.llm is not None:
                if hasattr(self.llm, "enable_input_require_grads"):
                    self.llm.enable_input_require_grads()
                else:
                    self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            if self.visual_encoder is not None:
                if hasattr(self.visual_encoder, "enable_input_require_grads"):
                    self.visual_encoder.enable_input_require_grads()
                else:
                    self.visual_encoder.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
                self.visual_projector.enable_input_require_grads()

            if self.segmentor is not None:
                if hasattr(self.segmentor, "enable_input_require_grads"):
                    self.segmentor.enable_input_require_grads()
                else:
                    self.segmentor.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
                if hasattr(self, "seg_projector"):
                    self.seg_projector.enable_input_require_grads()
                if hasattr(self, "llm_projector"):
                    self.llm_projector.enable_input_require_grads()
                if hasattr(self, "seg_connector"):
                    self.seg_connector.enable_input_require_grads()
            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()
        else:
            self.gradient_checkpointing_disable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        self.use_segmentor_encoder_lora = segmentor_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(visual_encoder_lora, use_activation_checkpointing)
        if self.use_segmentor_encoder_lora:
            self._prepare_segmentor_for_lora(segmentor_lora, use_activation_checkpointing)

        state_dict = super().state_dict()
        if s1_pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(s1_pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)

            matched_keys = [k for k in pretrained_state_dict.keys() if k in state_dict.keys()]
            print_log(f"Load s1_pretrained_pth from {s1_pretrained_pth}", logger="current")
            print_log(f"Matched keys: {len(matched_keys)} / {len(pretrained_state_dict.keys())}", logger="current")

        if s2_pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(s2_pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)

            matched_keys = [k for k in pretrained_state_dict.keys() if k in state_dict.keys()]
            print_log(f"Load s2_pretrained_pth from {s2_pretrained_pth}", logger="current")
            print_log(f"Matched keys: {len(matched_keys)} / {len(pretrained_state_dict.keys())}", logger="current")

        self.visual_select_layer = visual_select_layer
        self.visual_select_indx = visual_select_indx
        self.seg_select_layers = seg_select_layers
        self.extract_seg_embeds = extract_seg_embeds
        self.sampler_input_feat = sampler_input_feat
        self.cond_type = cond_type
        self.llm_loss_weight = llm_loss_weight
        self.seg_loss_weight = seg_loss_weight

    @property
    def device(self):
        return get_device()

    @property
    def dtype(self):
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def _add_special_tokens(self, special_tokens):
        assert all(
            token in [DEFAULT_SEG_TOKEN, DEFAULT_PSTART_TOKEN, DEFAULT_PEND_TOKEN, DEFAULT_CLS_TOKEN]
            for token in special_tokens
        )
        num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.llm.resize_token_embeddings(len(self.tokenizer))

        self.seg_token_idx = -1
        self.cls_token_idx = -1
        self.pstart_token_idx = -1
        self.pend_token_idx = -1

        if DEFAULT_SEG_TOKEN in special_tokens:
            self.seg_token_idx = self.tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False)["input_ids"][0]
        if DEFAULT_CLS_TOKEN in special_tokens:
            self.cls_token_idx = self.tokenizer(DEFAULT_CLS_TOKEN, add_special_tokens=False)["input_ids"][0]
        if DEFAULT_PSTART_TOKEN in special_tokens:
            self.pstart_token_idx = self.tokenizer(DEFAULT_PSTART_TOKEN, add_special_tokens=False)["input_ids"][0]
        if DEFAULT_PEND_TOKEN in special_tokens:
            self.pend_token_idx = self.tokenizer(DEFAULT_PEND_TOKEN, add_special_tokens=False)["input_ids"][0]

    def _get_index_embeds(self, input_embeds, embed_ids):
        output_embeds = []
        for input_embed, embed_id in zip(input_embeds, embed_ids):
            unique_ids = torch.unique(embed_id[embed_id != -1])
            if len(unique_ids) == 0:
                continue

            embeds = torch.stack([input_embed[embed_id == idx].mean(dim=0) for idx in unique_ids])
            output_embeds.append(embeds)

        return output_embeds if len(output_embeds) > 0 else None

    def _process_embeds(self, cond_embeds, seg_embeds, task_name="genseg"):
        B = len(cond_embeds)
        embed_masks = None
        local_cond_lens = None
        global_cond_lens = None
        bg_embeds = self.bg_embeds.weight
        if task_name in ["genseg", "vgdseg", "gcgseg", "ovseg", "interseg"]:
            max_cond_len = max([x.shape[0] for x in cond_embeds])
            embed_masks = []
            for i, cond_embed in enumerate(cond_embeds):
                cond_embeds[i] = torch.cat(
                    [cond_embed, bg_embeds.clone().repeat(max_cond_len - cond_embed.shape[0], 1) + -1e9],
                    dim=0,
                )
                embed_masks.append(
                    torch.cat(
                        [
                            torch.ones(cond_embed.shape[0], device=cond_embed.device),
                            torch.zeros(max_cond_len - cond_embed.shape[0], device=cond_embed.device),
                        ]
                    )
                )
            bg_embeds = bg_embeds[None, ...].repeat(B, 1, 1)
            cond_embeds = torch.cat([torch.stack(cond_embeds), bg_embeds], dim=1)
            seg_embeds = torch.stack(seg_embeds) if seg_embeds is not None else None
            embed_masks = torch.cat([torch.stack(embed_masks), torch.ones((B, 1), device=cond_embeds.device)], dim=1)
        elif task_name in ["refseg", "reaseg"]:
            local_cond_lens = [x.shape[0] for x in cond_embeds]
            cond_embeds = torch.cat([torch.cat(cond_embeds), bg_embeds])
            cond_embeds = cond_embeds[None, ...].repeat(sum(local_cond_lens), 1, 1)
            seg_embeds = torch.cat(seg_embeds).unsqueeze(1) if seg_embeds is not None else None
        else:
            raise ValueError(f"Task name {task_name} is not supported in _process_embeds")

        return cond_embeds, seg_embeds, embed_masks, local_cond_lens, global_cond_lens

    def _get_vgd_labels(self, data_samples):
        def _get_attr_from_data_samples(data_samples, attr):
            return getattr(data_samples, attr, None) if data_samples is not None else None

        class_labels = _get_attr_from_data_samples(data_samples, "class_labels")
        sampled_labels = _get_attr_from_data_samples(data_samples, "sampled_labels")
        contiguous_labels = _get_attr_from_data_samples(data_samples, "contiguous_labels")

        if class_labels is not None:
            class_labels = [class_label.cpu().numpy().tolist() for class_label in class_labels]

        if contiguous_labels is not None:
            # convert labels to contiguous labels
            assert class_labels is not None and sampled_labels is not None
            class_labels = [
                [ordered_label.index(sampled_label[label]) for label in class_label]
                for ordered_label, sampled_label, class_label in zip(contiguous_labels, sampled_labels, class_labels)
            ]
            sampled_labels = [
                [ordered_label.index(label) for label in sampled_label]
                for ordered_label, sampled_label in zip(contiguous_labels, sampled_labels)
            ]
        return class_labels, sampled_labels

    def _get_vprompt_feats_and_masks(
        self, vprompt_feats, vprompt_masks, class_labels, contiguous_labels, sampled_labels
    ):
        sampled_feats = []
        sampled_masks = []
        new_sampled_labels = []

        # Process each batch
        for batch_idx, (
            batch_vprompt_feats,
            batch_vprompt_masks,
            batch_class_labels,
            batch_contiguous_labels,
        ) in enumerate(zip(vprompt_feats, vprompt_masks, class_labels, contiguous_labels)):
            batch_sampled_feats = torch.zeros(
                (len(batch_contiguous_labels), batch_vprompt_feats.shape[1]),
                dtype=batch_vprompt_feats.dtype,
                device=batch_vprompt_feats.device,
            )
            batch_sampled_masks = torch.zeros(
                (len(batch_contiguous_labels), batch_vprompt_masks.shape[1], batch_vprompt_masks.shape[2]),
                dtype=batch_vprompt_masks.dtype,
                device=batch_vprompt_masks.device,
            )
            new_batch_sampled_labels = []

            # Track used labels to avoid duplicate sampling
            used_labels = []
            used_poses = []

            for i, target_label in enumerate(batch_contiguous_labels):
                # Find matching positions across all batches
                pos_matches = [
                    (b_idx, pos)
                    for b_idx, batch_labels in enumerate(class_labels)
                    for pos, label in enumerate(batch_labels)
                    if label == target_label and (b_idx, pos) not in used_poses
                ]
                neg_matches = [
                    (b_idx, pos)
                    for b_idx, batch_labels in enumerate(class_labels)
                    for pos, label in enumerate(batch_labels)
                    if label not in used_labels and (b_idx, pos) not in used_poses and label not in batch_class_labels
                ]

                matches = pos_matches if pos_matches else neg_matches

                if matches:
                    selected_batch, selected_pos = matches[torch.randint(len(matches), (1,)).item()]
                    batch_sampled_feats[i] = vprompt_feats[selected_batch][selected_pos]
                    batch_sampled_masks[i] = vprompt_masks[selected_batch][selected_pos]
                    new_batch_sampled_labels.append(
                        sampled_labels[selected_batch][
                            contiguous_labels[selected_batch].index(class_labels[selected_batch][selected_pos])
                        ]
                    )
                    used_labels.append(class_labels[selected_batch][selected_pos])
                    used_poses.append((selected_batch, selected_pos))
                else:
                    # If no matches found, use default embedding
                    batch_sampled_feats[i] = torch.zeros_like(batch_vprompt_feats[0])
                    batch_sampled_masks[i] = torch.zeros_like(batch_vprompt_masks[0])
                    new_batch_sampled_labels.append(-1)

            sampled_feats.append(batch_sampled_feats)
            sampled_masks.append(batch_sampled_masks)
            new_sampled_labels.append(new_batch_sampled_labels)

        return sampled_feats, sampled_masks, new_sampled_labels

    def _get_attrs_from_data_samples(self, data_samples, attrs, **kwargs):
        if isinstance(attrs, str):
            attrs = [attrs]
        return [getattr(data_samples, attr, None) if data_samples is not None else None for attr in attrs]

    def forward(self, data_dict, data_samples=None, mode="loss", **kwargs):
        if data_samples is not None:
            data_samples = data_sample_to_device(data_samples, device=get_device())

        extra_data_dict = {}
        if "pixel_values" in data_dict and self.visual_encoder is not None:
            visual_outputs = self.visual_encoder(
                data_dict["pixel_values"].to(self.visual_encoder.dtype),
                output_hidden_states=True,
            )
            pixel_values = self.visual_projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, self.visual_select_indx :]
            )
            data_dict["pixel_values"] = pixel_values.to(self.llm.dtype)

        if "seg_pixel_values" in data_dict and self.segmentor is not None:
            if self.extract_seg_embeds:
                seg_visual_outputs = self.segmentor.encoder(
                    data_dict["seg_pixel_values"].to(self.segmentor.dtype),
                    output_hidden_states=True,
                    output_attentions=False,
                )
                seg_image_embeddings = (
                    seg_visual_outputs.last_hidden_state
                    if hasattr(seg_visual_outputs, "last_hidden_state")
                    else seg_visual_outputs.hidden_states[-1].transpose(1, 2)
                )
                seg_pixel_values = None
                if hasattr(self, "seg_projector"):
                    seg_pixel_values = self.seg_projector(seg_visual_outputs.hidden_states[self.visual_select_layer])
                    seg_pixel_values = seg_pixel_values.to(self.llm.dtype)

                if hasattr(self, "seg_connector"):
                    seg_image_embeddings = self.seg_connector(
                        [seg_visual_outputs.hidden_states[i] for i in self.seg_select_layers]
                    )
                elif self.segmentor.pixel_decoder is not None and hasattr(seg_visual_outputs, "feature_maps"):
                    seg_image_embeddings = seg_visual_outputs.feature_maps

                # here, seg_pixel_values is seg_projector output
                data_dict["seg_pixel_values"] = seg_pixel_values
                extra_data_dict = {
                    "seg_pixel_values": None,
                    "seg_image_embeddings": seg_image_embeddings,
                }
                del seg_visual_outputs
            else:
                # here, seg_pixel_values is image_processor output
                extra_data_dict = {
                    "seg_pixel_values": data_dict["seg_pixel_values"].to(self.segmentor.dtype),
                    "seg_image_embeddings": None,
                }
                data_dict["seg_pixel_values"] = None
        else:
            data_dict["seg_pixel_values"] = None

        if data_dict.get("vprompt_masks", None) is not None and hasattr(self, "vision_sampler"):
            vprompt_masks = data_dict.pop("vprompt_masks")
            class_labels, contiguous_labels = self._get_vgd_labels(data_samples)
            sampled_labels = self._get_attrs_from_data_samples(data_samples, ["sampled_labels"])[0]
            sampled_feats = self.vision_sampler(data_dict[self.sampler_input_feat], vprompt_masks)
            assert all(
                sampled_feat is not None for sampled_feat in sampled_feats
            ), f"{data_dict[self.sampler_input_feat]}, {vprompt_masks}"
            vprompt_feats, vprompt_masks, new_sampled_labels = self._get_vprompt_feats_and_masks(
                sampled_feats, vprompt_masks, class_labels, contiguous_labels, sampled_labels
            )
            data_dict["vprompt_feats"] = vprompt_feats
            kwargs["vprompt_masks"] = vprompt_masks
            kwargs["sampled_labels"] = sampled_labels

        if self.llm is not None:
            data_dict = prepare_inputs_labels_for_multimodal(llm=self.llm, **data_dict)

        data_dict.update(extra_data_dict)

        if mode == "loss":
            return self.compute_loss(data_dict, data_samples, **kwargs)
        elif mode == "predict":
            return self.predict(data_dict, data_samples, **kwargs)
        elif mode == "tensor":
            return self._forward(data_dict, data_samples, **kwargs)
        else:
            raise NotImplementedError

    def _forward(
        self,
        data_dict,
        data_samples=None,
        **kwargs,
    ):
        if data_dict.get("inputs_embeds", None) is not None:
            data_dict["input_ids"] = None

        cond_ids = data_dict.pop("cond_ids", None)
        seg_ids = data_dict.pop("seg_ids", None)
        seg_pixel_values = data_dict.pop("seg_pixel_values", None)
        seg_image_embeddings = data_dict.pop("seg_image_embeddings", None)
        task_names, image_size, scaled_size, mask_labels, class_labels = self._get_attrs_from_data_samples(
            data_samples,
            [
                "task_names",
                "image_sizes",
                "scaled_sizes",
                "mask_labels",
                "class_labels",
            ],
            **kwargs,
        )
        task_names = task_names if task_names is not None else ["genseg"]
        assert (
            len(set(task_names)) == 1 and task_names[0] in DEFAULT_TASKS
        ), f"Task name {task_names} is not in {DEFAULT_TASKS}"

        seg_embeds = None
        cond_embeds = None
        embed_masks = None
        llm_outputs = None
        seg_outputs = None
        local_cond_lens = None
        global_cond_lens = None

        if self.llm is not None:
            llm_outputs = self.llm(**data_dict, output_hidden_states=True)

        if self.segmentor is None or self.segmentor.decoder is None:
            return llm_outputs, None

        if llm_outputs is not None:
            llm_hidden_states = llm_outputs.hidden_states
            llm_last_hidden_state = llm_hidden_states[-1]
            llm_embeds = self.llm_projector(llm_last_hidden_state)
            if cond_ids is not None:
                cond_embeds = self._get_index_embeds(llm_embeds, cond_ids)
            if seg_ids is not None:
                seg_embeds = self._get_index_embeds(llm_embeds, seg_ids)
            if cond_embeds is not None and seg_embeds is not None:
                cond_embeds, seg_embeds, embed_masks, local_cond_lens, global_cond_lens = self._process_embeds(
                    cond_embeds, seg_embeds, task_names[0]
                )

        if (local_cond_lens or global_cond_lens) is not None and mask_labels is not None:
            cur_rank = get_rank()
            mask_labels = list(chain(*[mask_label.split(1) for mask_label in mask_labels]))
            if global_cond_lens is not None:
                label_offsets = (
                    list(accumulate([sum(torch.cat(global_cond_lens[:cur_rank])).item()] + local_cond_lens[:-1]))
                    if cur_rank > 0
                    else list(accumulate([0] + local_cond_lens[:-1]))
                )
            else:
                label_offsets = list(accumulate([0] + local_cond_lens[:-1]))

            class_labels = list(
                chain(
                    *[
                        (class_label + label_offset).split(1)
                        for label_offset, class_label in zip(label_offsets, class_labels)
                    ]
                )
            )

        if seg_embeds is not None or llm_outputs is None:
            if seg_embeds is not None and seg_embeds.shape[1] != 1:
                seg_outputs = None
            else:
                seg_outputs = self.segmentor(
                    pixel_values=seg_pixel_values,
                    image_embeddings=seg_image_embeddings,
                    cond_embeddings=cond_embeds,
                    seg_embeddings=seg_embeds,
                    embed_masks=embed_masks,
                    mask_labels=mask_labels,
                    class_labels=class_labels,
                    cond_lens=local_cond_lens,
                    return_dict=True,
                )
                if kwargs.pop("do_postprocess", False):
                    seg_outputs = self.postprocess_fn(
                        seg_outputs,
                        image_sizes=image_size,
                        scaled_sizes=scaled_size,
                        **kwargs,
                    )

        return llm_outputs, seg_outputs

    @torch.no_grad()
    def predict(self, data_dict, data_samples=None, **kwargs):
        if data_dict.get("inputs_embeds", None) is not None:
            data_dict["input_ids"] = None

        if data_dict.get("labels", None) is not None:
            data_dict["labels"] = None

        if data_dict.get("position_ids", None) is not None:
            data_dict["position_ids"] = None

        if data_dict.get("attention_mask", None) is not None:
            data_dict["attention_mask"] = None

        seg_ids = data_dict.pop("seg_ids", None)
        seg_pixel_values = data_dict.pop("seg_pixel_values", None)
        seg_image_embeddings = data_dict.pop("seg_image_embeddings", None)
        input_cond_ids = data_dict.pop("cond_ids", None)
        task_names, image_size, scaled_size = self._get_attrs_from_data_samples(
            data_samples,
            ["task_names", "image_sizes", "scaled_sizes"],
            **kwargs,
        )
        task_names = task_names if task_names is not None else ["genseg"]
        assert (
            len(task_names) == 1 and task_names[0] in DEFAULT_TASKS
        ), f"Task name {task_names} is not in {DEFAULT_TASKS}"

        generation_config = kwargs.pop("generation_config", None)
        stopping_criteria = kwargs.pop("stopping_criteria", None)

        seg_embeds = None
        cond_embeds = None
        llm_outputs = None
        seg_outputs = None
        local_cond_lens = None

        if self.llm is not None:
            llm_outputs = self.llm.generate(
                **data_dict,
                return_dict_in_generate=True,
                output_hidden_states=True,
                generation_config=generation_config,
                stopping_criteria=stopping_criteria,
            )

        if self.segmentor is None or self.segmentor.decoder is None:
            return llm_outputs, None

        if llm_outputs is not None:
            llm_output_ids = llm_outputs.sequences
            llm_hidden_states = llm_outputs.hidden_states
            input_hidden_states = llm_hidden_states[0][-1]
            llm_last_hidden_state = torch.cat([x[-1] for x in llm_hidden_states], dim=1)
            llm_input_embeds = self.llm_projector(input_hidden_states)
            llm_output_embeds = self.llm_projector(llm_last_hidden_state)

            L = input_hidden_states.shape[1]
            if input_cond_ids is not None:
                cond_embeds = self._get_index_embeds(llm_input_embeds, input_cond_ids)

            # update cond_embeds if there is pstart and pend token in the output
            pstart_idx = (llm_output_ids[..., :-1] == self.pstart_token_idx).nonzero()[:, 1]
            pend_idx = (llm_output_ids[..., :-1] == self.pend_token_idx).nonzero()[:, 1]
            cls_idx = (llm_output_ids[..., :-1] == self.cls_token_idx).nonzero()[:, 1]
            if len(pstart_idx) > 0 or len(cls_idx) > 0:
                output_cond_ids = torch.full(
                    llm_last_hidden_state.shape[:2], -1, dtype=torch.long, device=input_hidden_states.device
                )
                shift = llm_input_embeds.shape[1]
                if self.cond_type in ["phrase", "all"]:
                    for i, (pstart, pend) in enumerate(zip(pstart_idx, pend_idx)):
                        output_cond_ids[:, shift + pstart : shift + pend + 1] = i
                if self.cond_type in ["cls", "all"]:
                    for i, ci in enumerate(cls_idx):
                        output_cond_ids[:, shift + ci] = i

                cond_embeds = self._get_index_embeds(llm_output_embeds, output_cond_ids)

            # update seg_ids if there is seg token in the output
            seg_idx = (llm_output_ids[..., :-1] == self.seg_token_idx).nonzero()[:, 1]
            if len(seg_idx) > 0:
                # fmt: off
                B = (seg_image_embeddings.shape[0] if isinstance(seg_image_embeddings, torch.Tensor) 
                    else seg_image_embeddings[0].shape[0]) if self.extract_seg_embeds else seg_pixel_values.shape[0]
                assert B == 1, "Only support batch size 1 for prediction"
                # fmt: on
                seg_ids = torch.full_like(
                    llm_output_ids[..., :-1], -1, dtype=torch.long, device=input_hidden_states.device
                )
                for i, idx in enumerate(seg_idx):
                    seg_ids[:, idx] = i
                seg_ids = torch.cat(
                    [torch.full((B, L), -1, dtype=torch.long, device=input_hidden_states.device), seg_ids], dim=-1
                )
                seg_embeds = self._get_index_embeds(llm_output_embeds, seg_ids)

            if cond_embeds is not None and seg_embeds is not None:
                cond_embeds, seg_embeds, embed_masks, local_cond_lens, _ = self._process_embeds(
                    cond_embeds, seg_embeds, task_names[0]
                )

        if (cond_embeds is not None and seg_embeds is not None) or llm_outputs is None:
            if seg_embeds is not None and seg_embeds.shape[1] != 1:
                seg_outputs = None
            else:
                seg_outputs = self.segmentor(
                    pixel_values=seg_pixel_values,
                    image_embeddings=seg_image_embeddings,
                    cond_embeddings=cond_embeds,
                    seg_embeddings=seg_embeds,
                    embed_masks=embed_masks,
                    cond_lens=local_cond_lens,
                    return_dict=True,
                )
                if kwargs.pop("do_postprocess", True):
                    seg_outputs = self.postprocess_fn(
                        seg_outputs,
                        image_sizes=image_size,
                        scaled_sizes=scaled_size,
                        **kwargs,
                    )
        return llm_outputs, seg_outputs

    def compute_loss(self, data_dict, data_samples=None, **kwargs):
        llm_outputs, seg_outputs = self._forward(data_dict, data_samples, **kwargs)
        loss, loss_llm, loss_seg = 0.0, 0.0, 0.0
        if llm_outputs is not None and seg_outputs is None:
            loss_llm = llm_outputs.loss * self.llm_loss_weight
            loss = loss_llm
            loss_dict = {"loss": loss, "loss_llm": loss_llm}
        elif llm_outputs is None and seg_outputs is not None:
            loss_seg = seg_outputs.loss * self.seg_loss_weight
            loss_seg_dict = {k: v * self.seg_loss_weight for k, v in seg_outputs.loss_dict.items()}
            loss = loss_seg
            loss_dict = {"loss": loss, "loss_seg": loss_seg}
            loss_dict.update(loss_seg_dict)
        elif llm_outputs is not None and seg_outputs is not None:
            loss_llm = llm_outputs.loss * self.llm_loss_weight
            loss_seg = seg_outputs.loss * self.seg_loss_weight
            loss_seg_dict = {k: v * self.seg_loss_weight for k, v in seg_outputs.loss_dict.items()}
            loss = loss_llm + loss_seg
            loss_dict = {"loss": loss, "loss_llm": loss_llm, "loss_seg": loss_seg}
            loss_dict.update(loss_seg_dict)
        else:
            raise ValueError("llm_outputs and seg_outputs are both None")

        return loss_dict

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.visual_encoder is not None:
            if self.use_visual_encoder_lora:
                to_return.update(get_peft_model_state_dict(self.visual_encoder, state_dict=state_dict))
            elif not self.freeze_visual_encoder:
                to_return.update({k: v for k, v in state_dict.items() if "visual_encoder." in k})
        # Step 2. segmentor
        if self.segmentor is not None:
            if self.use_segmentor_encoder_lora:
                to_return.update(get_peft_model_state_dict(self.segmentor.encoder, state_dict=state_dict))
            elif not self.freeze_segmentor_encoder:
                to_return.update({k: v for k, v in state_dict.items() if "segmentor.encoder" in k})

            # segmentor other parts except encoder
            to_return.update(
                {k: v for k, v in state_dict.items() if "segmentor" in k and "segmentor.encoder" not in k}
            )
        # Step 3. LLM
        if self.llm is not None:
            if self.use_llm_lora:
                to_return.update(get_peft_model_state_dict(self.llm, state_dict=state_dict))
            elif not self.freeze_llm:
                to_return.update({k: v for k, v in state_dict.items() if "llm." in k})
        # Step 4. Projector
        to_return.update({k: v for k, v in state_dict.items() if "visual_projector." in k})
        to_return.update({k: v for k, v in state_dict.items() if "seg_projector." in k})
        to_return.update({k: v for k, v in state_dict.items() if "llm_projector." in k})
        # Step 5. seg_connector
        to_return.update({k: v for k, v in state_dict.items() if "seg_connector." in k})
        # Step 6. other embeds
        to_return.update({k: v for k, v in state_dict.items() if "bg_embeds." in k})
        to_return.update({k: v for k, v in state_dict.items() if "vgd_embeds." in k})
        # Step 7. vision_sampler
        to_return.update({k: v for k, v in state_dict.items() if "vision_sampler." in k})
        return to_return

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, dict) or isinstance(lora_config, Config) or isinstance(lora_config, ConfigDict):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self, lora_config, use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def _prepare_segmentor_for_lora(self, lora_config, use_activation_checkpointing=True):
        if self.segmentor is None:
            return
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.segmentor.encoder)
            lora_config.target_modules = modules
        self.segmentor = get_peft_model(self.segmentor.encoder, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        if self.llm is not None:
            self.llm.gradient_checkpointing_enable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_enable()
            self.visual_projector.gradient_checkpointing_enable()
        if self.segmentor is not None:
            self.segmentor.gradient_checkpointing_enable({"use_reentrant": False})
            if hasattr(self, "seg_projector"):
                self.seg_projector.gradient_checkpointing_enable()
            if hasattr(self, "llm_projector"):
                self.llm_projector.gradient_checkpointing_enable()
            if hasattr(self, "seg_connector"):
                self.seg_connector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        if self.llm is not None:
            self.llm.gradient_checkpointing_disable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_disable()
            self.visual_projector.gradient_checkpointing_disable()
        if self.segmentor is not None:
            self.segmentor.gradient_checkpointing_disable()
            if hasattr(self, "seg_projector"):
                self.seg_projector.gradient_checkpointing_disable()
            if hasattr(self, "llm_projector"):
                self.llm_projector.gradient_checkpointing_disable()
            if hasattr(self, "seg_connector"):
                self.seg_connector.gradient_checkpointing_disable()

    def init_weights(self):
        pass

摘要:大型语言模型在广泛的知识表征方面展现出强大能力,但在像素级感知理解方面存在固有缺陷。尽管分割一切模型在视觉提示驱动的图像分割领域取得了重大进展,但它在多掩码预测和特定类别分割任务中存在明显局限性,且无法在统一模型架构中集成所有分割任务。为应对这些挑战,我们提出X-SAM——一个简化的多模态大语言模型框架,将分割范式从"分割一切"扩展至"任意分割"。具体而言,我们引入了一种新型统一框架,使多模态大语言模型能实现更先进的像素级感知理解。此外,我们提出名为"视觉定位分割"的新任务,该任务通过交互式视觉提示分割所有实例对象,并赋予多模态大语言模型具有视觉定位的像素级解析能力。为实现多数据源的有效训练,我们提出支持跨数据集协同训练的统一训练策略。实验结果表明,X-SAM在广泛的图像分割基准测试中达到了最先进性能,彰显了其在多模态像素级视觉理解方面的卓越效能。

1 引言

随着大语言模型的快速发展以及多模态预训练方法的进步,多模态大语言模型展现了显著的进展。这些模型在广泛的应用中表现出卓越的有效性,包括图像描述、视觉问答和视觉编辑。然而,开发一个真正通用化模型仍然存在一个重大障碍:当前的MLLMs仅限于生成纯文本输出。这一限制对直接处理需要像素级视觉数据理解的任务构成了相当大的挑战,例如图像分割——这是计算机视觉领域中最关键的任务之一。

Segment Anything Model(SAM)代表了一个基础分割模型,它在生成密集分割掩码方面展示了卓越的功效,并启发了各种分割任务的发展,例如高质量分割、匹配任意对象和跟踪任意对象。然而,SAM的架构从根本上受限于其对视觉提示的依赖,这显著限制了其在广泛图像分割任务中的直接适用性,包括通用分割、参考表达式分割和开放词汇分割等。实现一个能够处理各种图像分割任务的统一框架仍然是一个具有挑战性的问题。

在这项工作中,我们介绍了X-SAM,一个创新的框架,它统一了多样化的图像分割任务,将分割范式从"分割任何事物"扩展到了"任何分割"。为实现此目标,我们的方法解决了三个关键的技术挑战:(1)任务形式化:将SAM转变为具有跨任务适用性的通用分割架构。(2)模态增强:增强LLMs以具备多模态输入处理能力。(3)统一框架:开发一种有效促进跨不同领域的全面分割应用的连贯方法。

首先,我们开发了一个统一的MLLM分割架构,它包含一个统一的掩码解码器,能够生成适用于广义图像分割任务的分割掩码。其次,我们扩展了MLLMs的多模态能力,使其不仅能处理文本查询,还能处理视觉查询。具体来说,我们引入了一项称为视觉接地分割的新任务,该任务通过交互式视觉提示来分割图像中的所有实例对象。这项任务将视觉引导模态引入了大语言模型。此外,我们提出了一个统一的输入格式和训练方法,在一个统一框架内重新形式化分割任务,从而优化MLLMs对各种图像分割任务的适应。

如图1和表1所示,我们展示了X-SAM的全面能力,并与其他方法进行了比较。我们提出的框架在处理基于文本查询的任务方面展现出能力,同时也能适应基于视觉查询的任务。此外,X-SAM利用LLMs的推理和生成能力,从而实现了先进的推理分割和接地对话生成分割。

X-SAM与多种多样的数据集进行了协同训练。我们在七个不同的图像分割任务上,对超过二十个分割数据集进行了全面评估,甚至包括图像转换任务。X-SAM在所有图像分割基准测试中均实现了最先进的性能,并为统一的像素级图像理解建立了一个强大的新基准。

总之,我们的贡献如下:

  • 我们介绍了X-SAM,一个新颖的统一框架,它将分割范式从"分割任何事物"扩展到了"任何分割"。我们的方法将多样化的图像分割任务形式化为标准化的分割格式。
  • 我们提出了一个新的图像分割基准,视觉接地分割,它为MLLMs提供视觉接地提示,以分割图像中的实例对象。该基准引入了用户友好的输入来接地分割对象,并引导MLLMs输出分割掩码。
  • 我们提出了一种统一的多阶段训练策略,用多种数据集协同训练X-SAM,并在超过二十个图像分割基准上进行了广泛评估,在所有基准上都实现了最先进的性能。这为MLLMs中统一的像素级感知理解建立了一个新的强大基准。

2 相关工作

多模态大语言模型 多模态学习经历了从早期专注于特定任务的融合和特征提取的模型,发展到利用大语言模型来处理广义的、经过指令调优的多任务基准。LLaVA 引入了视觉特征标记化,启发了视觉表征、专用视觉扩展以及语言引导分割等领域的进展。然而,大多数进展仍局限于特定任务。据我们所知,我们是首个成功实现全面方法的研究,为图像分割开辟了新的方向。

多模态接地分割 近期研究探索了视觉领域的视觉初始化方法,包括可学习标记、掩码视觉建模和视觉提示编码器。SAM 及其扩展将视觉接地信号引入分割模型,极大地提升了性能。交互式分割进一步增强了 MLLMs 中用户引导的分割能力。然而,现有方法无法自由地将接地输入视为文本输入来处理分割任务。为解决此问题,我们提出了视觉接地分割,以实现更多样化的多模态接地分割。

统一分割模型 视觉 Transformer 推动了通用分割的发展,近期研究工作开发了端到端的掩码分类框架,在各种应用中超越了早期模型。研究已扩展到开放世界和开放词汇分割,以及面向多任务的统一架构。然而,大多数方法仅专注于视觉分割,缺乏 MLLMs 中具备的交互式文本和视觉提示能力。为解决此问题,我们将 SAM 与 MLLMs 相结合,将 SAM 从"分割任何事物"扩展至"任何分割",并引入了一个能适应所有图像分割任务的统一框架,从而建立了一个新的强大基准。

3 方法

为实现统一的图像分割,我们提出了X-SAM,一种新颖的多模态分割MLLM。我们设计了一种通用的输入格式和一个统一的框架,将不同的分割任务集成到单个模型中。此外,我们引入了一种创新的训练策略,使SAM能够处理任何分割任务。以下各节将详细说明我们的方法。

3.1 形式化

统一分割模型的开发面临着诸多挑战,这些挑战源于分割任务的多样性以及输入格式的可变性。为解决这些问题,我们引入了一种通用的输入格式,旨在支持广泛的图像分割任务,为X-SAM的统一框架奠定基础。我们将输入格式划分为两个主要类别:文本查询输入视觉查询输入。文本查询输入仅包含来自用户请求的语言提示,而视觉查询输入则整合了用户提供的语言提示和视觉提示。

文本查询输入。 大多数现有的图像分割任务可以被概念化为文本查询输入,包括通用分割、参考表达式分割、开放词汇分割、GCG分割和推理分割。文本查询输入封装了用户的请求以及要分割的特定类别或对象,这些信息可能嵌入在用户的提示中或由大语言模型生成。为了促进GCG分割任务,受GLaMM的启发,我们在分词器中加入了两个特殊的短语标记<p></p>,分别表示短语的开始和结束。对于通用分割和GCG分割中的每个类别、参考表达式分割中的短语或推理分割中的句子,其格式均标准化为"<p>类别/短语/句子</p>"。具体来说,<p></p>标记不仅在输入标记中被编码,也在输出标记中生成,确保了不同任务间的一致性。此外,对于输出,我们借鉴的方法,在分词器中引入了一个特殊标记<SEG>来表示分割结果。

视觉查询输入。 除了文本查询输入,某些任务需要视觉查询输入,例如交互式分割和本文提出的视觉接地分割。与文本查询输入不同,视觉查询输入包含了来自用户的视觉提示,其形式可以是点、涂鸦线、框或掩码。为了表示视觉提示,我们在输入格式中使用了一个专用标记<region>。与文本查询输入类似,视觉提示的格式为"<p><region></p>",分割输出同样由<SEG>标记指示。<region>标记充当视觉提示的占位符,并将被分割编码器提取的区域特征所替换。

统一形式化。 <p></p>标记之间的潜在语言嵌入被用作分割解码器的条件嵌入来计算分类分数。基于此形式化,我们实现了适用于所有图像分割任务的统一框架。给定一个输入图像 Xv∈RH×W×3Xv​∈RH×W×3 和一个语言指令 Xq∈RP×1Xq​∈RP×1,模型将图像和语言指令作为输入,并输出一个语言响应 Yq∈RL×1Yq​∈RL×1 和一个分割掩码 Ym∈RH×WYm​∈RH×W。这里,PP 是输入文本标记的长度,LL 是输入和输出文本标记的总长度。HH 和 WW 分别表示图像的高度和宽度。详细的输入格式示例可见图1 (a) 和 (b)。

3.2 架构

在本节中,我们提出了X-SAM,一个用于"任何分割"的统一分割MLLM。如图2所示,它包含双编码器、双投影器、一个LLM、一个分割连接器和一个分割解码器。

图2:X-SAM 概述。X-SAM 包含双编码器、双投影器、一个语言模型、一个分割连接器和一个分割解码器。双编码器处理图像并将特征投影到与文本嵌入维度匹配,然后与标记化文本一同输入语言模型,以实现指令引导的理解。SAM特征被连接到分割解码器,该解码器利用LLM的<SEG>标记来生成分割掩码。

双编码器。 X-SAM中有两个编码器:一个图像编码器和一个分割编码器。图像编码器 ff 用于提取全局图像特征 Zv=f(Xv)Zv​=f(Xv​),而分割编码器 gg 则提取细粒度的图像特征 Zs=g(Xv)Zs​=g(Xv​)。来自图像编码器的特征是全局的,有利于图像理解任务;而来自分割编码器的特征是细粒度的,有利于图像分割任务。我们采用SigLIP2-so400m作为图像编码器,并采用SAM-L作为分割编码器。

双投影器。 为了增强LLM对图像的理解,我们在将特征传递给LLM之前,拼接了来自图像编码器和分割编码器的特征。具体来说,来自分割编码器的特征过大,无法被LLM直接处理,因此我们利用像素重排操作来减少其空间尺寸。然后我们通过一个MLP投影器 WsWs​ 将降维后的特征投影到语言嵌入空间 HqHq​。对于来自图像编码器的特征,我们通过一个MLP投影器 WiWi​ 直接将其投影到语言嵌入空间,即 Hv=Wi⋅ZvHv​=Wi​⋅Zv​ 和 Hs=Ws⋅ZsHs​=Ws​⋅Zs​。然后,我们将来自双投影器的特征和语言嵌入进行拼接,并将它们输入到LLM fϕfϕ​ 中。

分割连接器。 对于图像分割任务,细粒度的多尺度特征对于分割解码器准确预测分割掩码至关重要。SAM中分割编码器的输出是单尺度的(1/16),空间分辨率较低。为了获得多尺度特征,我们设计了一个分割连接器 gcgc​,以桥接分割编码器和解码器。如图所示,该连接器包含多个上采样和卷积层,用于从分割编码器输出的单尺度特征生成多尺度特征。

如图3所示,我们采用尺度为0.5的像素重排操作进行块合并,将编码器中最后一个特征图的空间尺寸缩减到更小的尺度(1/32)。同时,我们采用尺度为2.0的像素重排操作进行块扩展,将最后一个特征图的空间尺寸增大到更大尺度(1/8),从而为分割解码器生成多尺度特征。

分割解码器。 Segment Anything Model可以根据输入的文本或视觉提示分割单个对象,但无法在单次推理中分割所有对象。为了能一次性分割所有对象,我们遵循的方法,用一个新的解码器替换了其原始的分割解码器。分割解码器 gψgψ​ 根据输入的潜在嵌入 EiEi​ 或输出的潜在嵌入 EoEo​、多尺度分割特征 FcFc​ 以及一组掩码查询标记加上<SEG>标记嵌入(它桥接了LLM输出与分割解码器)来预测掩码及其类别概率。值得注意的是,我们引入了一个潜在的背景嵌入来表示所有任务的"忽略"类别,从而用一个模型统一了所有图像分割任务。

3.3 训练

为了提高在各种图像分割任务上的性能,我们提出了一种新颖的多阶段训练策略。该训练策略包含三个阶段:分割器微调、对齐预训练和混合微调。

阶段1:分割器微调。 由于分割解码器被重新设计,我们需要训练分割器以适应在单次前向传播中分割所有对象。我们遵循的训练流程,在流行的COCO-Panoptic数据集上训练模型。为了在训练期间实现更快的收敛,我们解冻了分割器中的所有参数,同时以较低的学习率训练分割编码器。训练目标 LsegLseg​ 与中的相同,定义为分类损失 LclsLcls​、掩码损失 LmaskLmask​ 和dice损失 LdiceLdice​ 之和:

阶段2:对齐预训练。 为了对齐语言嵌入和视觉嵌入,我们按照的方法,在LLaVA-558K数据集上执行对齐预训练。我们保持双编码器和LLM参数冻结,仅训练双投影器。通过这种方式,图像嵌入和分割嵌入可以与预训练的LLM词嵌入对齐。对齐预训练的训练目标是一个自回归损失 LregressiveLregressive​:

其中 Xq=[x1,x2,...,xp]∈RP×DXq​=[x1​,x2​,...,xp​]∈RP×D 是输入序列,Yq=[y1,y2,...,yl]∈RL×DYq​=[y1​,y2​,...,yl​]∈RL×D 是输出序列,其中 L=P+NL=P+N 表示输出序列的长度,DD 表示LLM的隐藏大小。θθ 是LLM中的可训练参数,并且我们仅计算生成文本的损失。

阶段3:混合微调。 X-SAM以端到端的方式在多个任务的数据集上进行协同训练。对于图像对话任务,我们采用MLLM训练中常见的自回归损失 LregressiveLregressive​。对于分割任务,我们不仅使用分割器训练中的分割损失,还在训练目标中加入了自回归损失。得益于统一的形式化和简单的训练目标,跨不同任务的端到端混合微调可以在一个统一框架内执行。混合微调的训练目标可以表述为:

4 实验

4.1 实验设置

数据集与任务。 对于分割器微调,我们在COCO-Panoptic数据集上进行训练。对于对齐预训练,我们使用LLaVA-558K数据集。对于端到端混合微调,我们将一个图像对话数据集和五种类型的图像分割数据集纳入训练过程。为了平衡这些不同数据集之间的训练数据,我们将训练周期设置为1,并使用数据集平衡重采样来调整不同数据集的重采样率。训练完成后,X-SAM能够执行多种任务,包括图像对话、通用分割、参考表达式分割、推理分割、GCG分割、交互式分割和VGD分割。此外,X-SAM支持开放词汇分割,使其能够分割输入提示所定义的所有对象,即使是以前从未见过的对象。请注意,COCO-VGD是我们提出的基于COCO2017数据集构建的VGD分割数据集。数据集的详细信息在附录A.1中给出。

评估指标。 我们进行了广泛的实验来评估X-SAM的性能。对于通用分割和开放词汇分割,我们分别使用PQ、mIoU和mAP作为全景分割、语义分割和实例分割的主要指标。对于参考表达式分割和推理分割,我们采用cIoU和gIoU作为指标。对于GCG分割,我们使用METEOR、CIDEr、AP50和mIoU作为指标。对于交互式分割,我们使用mIoU和cIoU。对于VGD分割,我们使用AP和AP50。对于图像对话,我们采用常见MLLM基准测试的得分作为主要指标。

实现细节。 我们采用XTuner代码库进行训练和评估。在分割器微调期间,我们训练所有参数,设置批大小为64,SAM编码器的学习率为1e-5,其他参数的学习率为1e-4。训练周期数设置为36。对于对齐预训练,我们仅训练双投影器参数,批大小为256,学习率为1e-3,训练一个周期。对于端到端混合微调,我们训练所有参数,设置批大小为64,双编码器的学习率为4e-6,其他参数的学习率为4e-5,训练一个周期。所有训练均在16块A100 GPU上进行。对于图像对话评估,我们使用VLMEvalKit代码库来评估MLLM基准测试的性能。对于分割任务评估,我们遵循相应论文和代码库中描述的设置。更多实现细节在附录A.3中提供。

4.2 主要结果

我们在七个分割任务上进行了广泛评估,包括通用分割、开放词汇分割、参考表达式分割、推理分割、GCG分割、交互式分割和VGD分割。

总体表现。 在表2中,我们将X-SAM与当前的分割专用模型和MLLMs进行了比较。X-SAM展示了最全面的能力。它在通用分割上取得了与最先进方法相媲美的性能,并在其他基准测试上使用单一模型取得了最佳性能。X-SAM为图像分割基准测试树立了新的最先进记录。每个任务的详细结果讨论如下。

参考表达式分割。 我们在RefCOCO、RefCOCO+和RefCOCOg上评估X-SAM,结果如表3所示。在RefCOCO、RefCOCO+和RefCOCOg的验证集上,X-SAM分别比PSALM高出1.5% cIoU、5.1% cIoU和10.0% cIoU。与Sa2VA-8B相比,X-SAM以更小的模型规模取得了更好的结果,在RefCOCO、RefCOCO+和RefCOCOg上分别实现了3.5% cIoU、1.8% cIoU和5.1% cIoU的性能提升。

GCG分割。 接地对话生成需要细致的图像和像素级理解,要求MLLMs将描述的对象与其分割掩码联系起来。如表4所示,与先前的方法相比,X-SAM实现了显著的性能提升,并在Val集和Test集上都获得了最佳结果。在图像级理解方面,X-SAM在Val集上比GLaMM高出0.2% METEOR和3.2% CIDEr,在Test集上高出0.5% METEOR和4.8% CIDEr。在像素级理解方面,X-SAM在Val集上比OMG-LLaVA高出3.3% AP和3.9% mIoU,在Test集上高出4.3% AP和4.3% mIoU。

VGD分割。 视觉接地分割需要视觉查询理解,要求MLLMs理解视觉模态并分割所有相关实例。表5展示了VGD分割的结果。由于VGD分割是我们新提出的任务,我们按照X-SAM的设置评估了PSALM。在点、涂鸦线、框和掩码视觉提示上,X-SAM分别比PSALM高出45.9% AP、45.9% AP、45.8% AP和47.4% AP。

其他分割和对话基准测试的更多结果和讨论在附录A.5中提供。

4.3 消融实验

我们对混合微调、双编码器、多阶段训练和分割器架构进行了消融研究,由于篇幅限制,仅展示部分基准测试结果。

混合微调。 我们消融研究了混合微调对X-SAM性能的影响。如表6所示,混合微调提高了在域外COCO基准测试上的性能,证明了X-SAM强大的分割能力——例如,在A150-OV上AP提高了6.0%,在Reason-Val上gIoU提高了8.9%。然而,由于多源训练中平衡性能的挑战,它导致COCO-Pan上的PQ下降了0.8%。

双编码器。 我们消融研究了X-SAM中双编码器的设计。如表7所示,带有SAM或Swin编码器的双编码器均有益于VGD分割,在COCO-VGD上分别达到了7.2% AP和7.9% AP。此外,带有SAM编码器的双编码器在GCG-Val和A150-OV上持续提升性能,而缺乏强大分割能力的Swin编码器仅在A150-OV上提供小幅改进,甚至对GCG-Val产生负面影响。

多阶段训练。 我们消融研究了多阶段训练策略的影响。如表8所示,S1分割器微调阶段提升了分割能力,在COCO-Pan和A150-OV数据集上分别带来了9.3% PQ和1.5% AP的显著提升。同时,S2对齐预训练阶段增强了图像理解能力,在Conv.-MMB上额外贡献了2.1%的准确率。通过整合这些阶段,X-SAM在图像分割和理解方面展现出强劲的进步,确立了其在处理复杂视觉任务方面的有效性。

分割器架构。 我们通过进行12个周期的分割器微调来消融研究分割器架构的影响。如表9所示,M2F解码器带来了巨大的改进,提升了9.2% PQ,这归功于M2F的有效设计。卷积连接器的性能优于MLP连接器,因为卷积的空间感知能力有利于分割,而多尺度特征通过提供更多样化的尺度特征进一步提升了性能(10.7% PQ)。

更多消融实验结果可在附录A.6中找到。

5 结论

在这项工作中,我们提出了X-SAM,一个统一的分割MLLM,它将分割范式从"分割任何事物"扩展到"任何分割",将所有图像分割任务集成到一个单一模型中。我们的方法可以处理MLLMs中的各种多模态输入,包括文本和视觉查询。此外,为了使MLLMs具备视觉接地感知能力,我们引入了一个新的分割任务——视觉接地分割,进一步扩展了统一分割模型的能力。我们在所有图像分割任务上进行了广泛的实验,X-SAM使用单一模型在每个任务上都达到了最先进的性能。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 博主简介
  • 原理介绍
  • X-SAM/xsam/xsam/model/xsam.py
    • 1 引言
    • 2 相关工作
    • 3 方法
    • 3.1 形式化
    • 3.3 训练
    • 4 实验
    • 4.1 实验设置
    • 4.2 主要结果
    • 5 结论
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档