
Recent Multimodal RAG Papers (ColPali, SV-RAG, URaG, MetaEmbed)

by Air’s Big Data 2026. 2. 14.
 

Multimodal Retrieval-Augmented Generation (RAG) has evolved rapidly over the last two years. Although these systems are often discussed under the same umbrella, they operate at different architectural levels, retrieval units, and deployment regimes.

This post examines four influential systems:

  • ColPali
  • SV-RAG
  • URaG
  • MetaEmbed

Where Does Retrieval Happen?

A principled way to distinguish these methods is to consider where retrieval occurs in the pipeline.

  • ColPali — Core idea: represent each page as patch-level vectors and score query–page relevance via late interaction (MaxSim) for fine-grained, OCR-free page retrieval. Retrieval scope: page images from a corpus. Where retrieval happens: external multi-vector retriever.
  • SV-RAG — Core idea: add retrieval and QA adapters (e.g., LoRA) so the MLLM performs in-context page selection and answering without a separate retriever/index. Retrieval scope: pages within a long document. Where retrieval happens: inside the MLLM via adapters.
  • URaG — Core idea: use early-layer hidden states to score pages and prune irrelevant pages during the forward pass, reducing compute and interference for long-document reasoning. Retrieval scope: pages within a long document. Where retrieval happens: inside early Transformer layers.
  • MetaEmbed — Core idea: use learnable meta tokens to produce nested, compact multi-vector embeddings, enabling explicit test-time scaling of retrieval cost versus accuracy. Retrieval scope: generic multimodal items. Where retrieval happens: external compact multi-vector retriever.


ColPali

ColPali adopts a late-interaction multi-vector retrieval paradigm. Its main idea is simple but powerful: do not compress a document page into a single vector. Instead, keep a sequence of token/patch embeddings and use a ColBERT-style MaxSim scoring function that aligns query tokens to the best-matching page tokens.

Multi-vector embeddings from the VLM backbone (PaliGemma hidden states projected to 128-d tokens)

ColPali reuses a vision-language model (e.g., PaliGemma) but turns it into a retriever by taking the last hidden states, projecting to a fixed embedding dimension (128), L2-normalizing each token, masking padding tokens, and optionally keeping only image-token embeddings at inference:

# modeling_colpali.py
outputs = self.model(*args, output_hidden_states=True, **kwargs)
last_hidden_states = outputs.hidden_states[-1]                 # (B, L, H)
proj = self.custom_text_proj(last_hidden_states)             # (B, L, 128)

proj = proj / proj.norm(dim=-1, keepdim=True)                 # L2 normalization
proj = proj * kwargs["attention_mask"].unsqueeze(-1)          # zero-out padding

# optional: keep only image-token embeddings at inference
if "pixel_values" in kwargs and self.mask_non_image_embeddings:
    image_mask = (kwargs["input_ids"] == self.config.image_token_index).unsqueeze(-1)
    proj = proj * image_mask

return proj

Instead of a single page vector, ColPali outputs a sequence of token embeddings of shape (B, L, 128), where L covers image patches and possibly text tokens. This is the foundation that enables late interaction.

OCR-free page processing

A document page is treated as an image input to the VLM processor. ColPali uses a fixed visual prompt:

# processing_colpali.py
visual_prompt_prefix: ClassVar[str] = "<image><bos>Describe the image."

batch_doc = self(
    text=[self.visual_prompt_prefix] * len(images),
    images=images,
    return_tensors="pt",
    padding="longest",
)
return batch_doc

This makes retrieval OCR-free: the page image is encoded directly through the multimodal model, producing patch/token embeddings that preserve layout and visual structure.

Query augmentation: reasoning buffer tokens

ColPali optionally appends a suffix of repeated tokens (by default, the pad token) to every query. This is described in code as a reasoning buffer:

# processing_utils.py
if suffix is None:
    suffix = self.query_augmentation_token * 10

texts = [self.query_prefix + text + suffix for text in texts]
return self.process_texts(texts=texts)

# processing_colpali.py
@property
def query_augmentation_token(self) -> str:
    return self.tokenizer.pad_token  # reasoning buffers during inference

Appending extra tokens gives the model more slots for late interaction, often improving robustness for complex queries over dense pages.

MaxSim scoring (late interaction)

ColPali does token-level alignment: (1) compute all dot products between query and page tokens; (2) for each query token, keep the maximum similarity across page tokens; (3) sum over query tokens. In code:

# processing_utils.py
scores_batch.append(
    torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)  # (Bq, Bp, Nq, Np)
    .max(dim=3)[0]                                     # max over passage tokens -> (Bq, Bp, Nq)
    .sum(dim=2)                                        # sum over query tokens -> (Bq, Bp)
)

Because tensors can be large, scores are computed in blocks over queries and passages:

for i in range(0, len(qs), batch_size):
    qs_batch = pad_sequence(qs[i:i+batch_size], batch_first=True, padding_value=0).to(device)
    for j in range(0, len(ps), batch_size):
        ps_batch = pad_sequence(ps[j:j+batch_size], batch_first=True, padding_value=0).to(device)
        ...

Single-vector retrieval scores a pair with one dot product, score(q, p) = q · p. ColPali instead uses score(q, p) = Σ_t max_s ⟨q_t, p_s⟩: each query token finds its best-matching patch on the page, which makes the score robust to layout shifts and strong for tables, charts, and multi-column pages.
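
As a toy illustration of the difference (made-up tensors, not the library API):

import torch

# Toy example (random values): 3 query tokens, 5 page tokens, embedding dim 4.
q = torch.nn.functional.normalize(torch.randn(3, 4), dim=-1)   # (Nq, d) query token embeddings
p = torch.nn.functional.normalize(torch.randn(5, 4), dim=-1)   # (Np, d) page token embeddings

# Single-vector retrieval: pool each side into one vector, take a single dot product.
single_vector_score = q.mean(dim=0) @ p.mean(dim=0)

# Late interaction (MaxSim): each query token picks its best-matching page token,
# then the per-query-token maxima are summed into the final page score.
sim = q @ p.T                                # (Nq, Np) token-level similarity matrix
maxsim_score = sim.max(dim=1).values.sum()   # max over page tokens, sum over query tokens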

Training: InfoNCE over late-interaction scores

The training loss mirrors the same late-interaction scoring, then applies cross-entropy where the positive document for query i is at index i in the batch:


# late_interaction_losses.py
raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
scores = raw.amax(dim=3).sum(dim=2)   # MaxSim then sum

return self.ce_loss(scores / self.temperature, pos_idx)

ColPali optionally normalizes scores by query length:

lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
scores = scores / lengths.unsqueeze(1)

This makes ColPali a retriever-first model: it directly optimizes late-interaction retrieval quality using in-batch negatives, in the same spirit as ColBERT extended to page images.
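
A minimal, self-contained sketch of this in-batch-negative objective (variable names and the temperature value are illustrative; the repo's loss class also supports the query-length normalization shown above):

import torch
import torch.nn.functional as F

def late_interaction_infonce(query_embeddings, doc_embeddings, temperature=0.02):
    """InfoNCE over MaxSim scores; the positive doc for query i sits at batch index i.

    query_embeddings: (B, Nq, D), doc_embeddings: (B, Nd, D), both L2-normalized.
    """
    raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
    scores = raw.amax(dim=3).sum(dim=2)                            # (B, B) score matrix
    pos_idx = torch.arange(scores.size(0), device=scores.device)   # diagonal entries are positives
    return F.cross_entropy(scores / temperature, pos_idx)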

Takeaway. ColPali is an OCR-free, page-image retriever that keeps multi-vector token/patch embeddings from a VLM backbone and uses ColBERT-style MaxSim late interaction to score query–page relevance. Instead of collapsing a page into a single embedding, it computes a token-level similarity matrix, selects the best-matching page token for each query token, and sums these maxima to form the final score, yielding strong retrieval accuracy for visually complex pages at the cost of larger indexes and higher retrieval-time compute.
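
For reference, a minimal retrieval round-trip with the colpali-engine package might look like the sketch below; the checkpoint name, dtype, and exact method signatures are assumptions that may differ across library versions.

import torch
from PIL import Image
from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.2"                     # assumed checkpoint name
model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda:0").eval()
processor = ColPaliProcessor.from_pretrained(model_name)

images = [Image.open("page_1.png"), Image.open("page_2.png")]   # page screenshots
queries = ["What is the revenue trend in 2023?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)           # expected shape: (num_pages, L_doc, 128)
    query_embeddings = model(**batch_queries)          # expected shape: (num_queries, L_q, 128)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)   # (num_queries, num_pages)
best_page = scores.argmax(dim=1)                       # index of the top-ranked page per query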


SV-RAG

SV-RAG implements retrieval inside the model. It introduces two adapters to an MLLM: one for evidence page retrieval and one for question answering. The model thus functions as a self-contained retriever within a long document, without requiring an external index. The design is modular and provides long-document capability without a full retrieval stack.

Col-Projection and Late Interaction

SV-RAG projects the LLM’s last hidden states into a low-dimensional retrieval space (128-d) and uses ColBERT-style MaxSim for relevance scoring:

# modeling_colInternvl2_4b.py
last_hidden_states = outputs.hidden_states[-1]
proj = self.custom_text_proj(last_hidden_states)
proj = proj / proj.norm(dim=-1, keepdim=True)
return proj

# processing_utils.py — score_multi_vector
scores_batch.append(
    torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
    .max(dim=3)[0]
    .sum(dim=2)
)

Dual LoRA: Retrieval vs. QA

Retrieval uses the ColPhi/ColInternvl2 model with LoRA. For QA, the retrieval LoRA is disabled so the same backbone generates answers from the base weights (or a separate QA LoRA when fine-tuned):

# model_util.py
def disable_lora_if_present(self):
    if isinstance(self.model, PeftModel):
        self.lm_model.disable_adapter()   # use base weights for QA

def ask(self, question_string, img_dir_list):
    self.disable_lora_if_present()
    # ... generate with lm_model ...

Retrieval and QA Pipeline

# model_util.py — SVRAG_InternVL2
def retrieve(self, query_list, image_list):
    text_embeddings = self.process_text(query_list)
    image_embeddings = self.process_image(image_list)
    similarity_score = self.processor.score_multi_vector(text_embeddings, image_embeddings)
    values, top_indices = torch.tensor(similarity_score).sort(descending=True)
    return values, top_indices

# test_sv_rag.py — end-to-end
_, retrieved_indices = model.retrieve(query_list=question_list, image_list=image_list)
model.disable_lora_if_present()
retrieved_index_top = retrieved_indices[0][:k]
true_img_list = [x for i, x in enumerate(image_list) if i in retrieved_index_top]
answer_model = model.ask(question, true_img_list)

Retrieval is a learned multi-vector behavior inside the VLM embedding space. No external vector DB is required.


URaG

URaG transforms early Transformer layers into a differentiable evidence selector that prunes irrelevant pages during the forward pass. This is not external retrieval or an adapter scoring head; it is structural modification of the transformer computation flow. URaG decides what to continue computing on during generation, rather than deciding what to read before generation.

The Three Parameters That Control the Mechanism

layer_for_retrieval: Evidence scoring happens at an intermediate layer (layer 6). URaG runs the model up to layer L, uses hidden states to compute page relevance, prunes pages, then continues the forward pass only with selected pages. This is architectural early-exit pruning.

remain_pages: After scoring, URaG keeps only top-k pages. Subsequent layers attend only to retained tokens, reducing attention cost, memory, and noise from irrelevant context. Pruning happens inside the same forward pass.

end_question_position: URaG must distinguish question tokens from page tokens. It isolates question representations from early hidden states to enable question-aware page scoring.

Core Forward Logic

# modeling_urag.py — conceptual structure
def urag_forward(inputs):
    # Step 1: Run the transformer up to the retrieval layer
    hidden_states = transformer_layers[:layer_for_retrieval](inputs)
    # Step 2: Project to the retrieval space, extract question & page representations
    projed_features = proj_layer(hidden_states)
    projed_features = projed_features / projed_features.norm(dim=-1, keepdim=True)
    question_feat = projed_features[..., :end_question_position, :]
    page_reps = extract_page_segments(projed_features)
    # Step 3: Score pages via late interaction (ColBERT-style)
    scores = [late_interaction_scorev2(question_feat, page_rep) for page_rep in page_reps]
    # Step 4: Select top-k pages, prune tokens belonging to the other pages
    drop_im_ids = indices_of_pages_below_topk(scores, k=remain_pages)
    mask = torch.ones(hidden_states.size(1), dtype=torch.bool)
    for idx in drop_im_ids:
        start, end = im_token_scope[idx]
        mask[start:end] = False
    hidden_states = hidden_states[:, mask, :]
    attention_mask, position_ids = apply_mask(...)
    # Step 5: Continue the forward pass with the pruned context
    outputs = transformer_layers[layer_for_retrieval:](hidden_states)
    return lm_head(outputs)

Late Interaction Scoring

URaG uses the same ColBERT-style MaxSim as ColPali:

# modeling_urag.py — late_interaction_scorev2
def late_interaction_scorev2(E_q, E_v):
    E_q, E_v = E_q.float().unsqueeze(0), E_v.float().unsqueeze(0)
    score = torch.einsum("bnd,bsd->bns", E_q, E_v).max(dim=2)[0].sum(dim=1)[0]
    return score
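
For intuition, a toy call with made-up shapes (assuming the function above is in scope):

import torch

E_q = torch.randn(12, 128)    # 12 question-token embeddings in the 128-d retrieval space
E_v = torch.randn(700, 128)   # ~700 visual-token embeddings for one candidate page

# Each question token is matched against its best page token; the maxima are summed
# into a single relevance score for this page.
page_score = late_interaction_scorev2(E_q, E_v)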

Pruning in Code

# modeling_urag.py — inference
if late_interaction_scores.size(0) > self.remain_pages:
    topk_scores, _ = torch.topk(late_interaction_scores, self.remain_pages)
    drop_im_ids = torch.where(late_interaction_scores < topk_scores[-1])[0]
    mask = torch.ones(hidden_states.size(1), dtype=torch.bool, device=hidden_states.device)
    for idx in drop_im_ids:
        start, end = im_token_scope[idx]
        mask[start:end] = False
    hidden_states = hidden_states[:, mask, :]
    attention_mask = attention_mask[:, mask]
    position_ids = position_ids[:, :, mask]

Demo Usage

# demo usage — query, page_images, and VQA_prompt are placeholders for your own inputs
from transformers import AutoProcessor

model = URaG_ForConditionalGeneration.from_pretrained("shi-yx/URaG-3B", trust_remote_code=True).cuda()
processor = AutoProcessor.from_pretrained("shi-yx/URaG-3B", trust_remote_code=True)
tokenizer = processor.tokenizer          # assumed: the processor exposes its tokenizer

model.remain_pages = 5                   # keep only the top-5 pages after early-layer scoring
model.layer_for_retrieval = 6            # score and prune pages after layer 6
model.end_question_position = -4 - len(tokenizer(VQA_prompt)['input_ids'])   # locate question tokens

inputs = processor(text=query, images=page_images, return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=512)

What Makes URaG Different

  • Not external retrieval: No vector DB, no separate retriever, no separate index.
  • Not adapter-based scoring only: Retrieval is embedded in the forward flow, not a side head.
  • Structural computation control: URaG modifies which tokens survive past layer L and which pages participate in deeper attention, directly reducing quadratic attention cost.

Retrieval and generation are jointly optimized; pruning is differentiable during training. This yields both efficiency gains (44–56% computational reduction) and reduced interference from irrelevant material.


MetaEmbed

MetaEmbed belongs to the late-interaction multi-vector family alongside ColPali but employs a different strategy. It appends learnable Meta Tokens whose hidden states serve as compact multi-vector embeddings. At serving time, the number of meta-token vectors used can be selected to trade retrieval quality for cost. It is designed for deployments where explicit cost control is required. Compared to ColPali’s dense fine-grained vectors, MetaEmbed offers compact, scalable representations with adjustable quality-efficiency trade-offs.

Meta Tokens

Instead of using all patch tokens (like ColPali) or a single [CLS]-style vector, MetaEmbed appends M learnable Meta Tokens to the input sequence. After the transformer forward pass, the hidden states corresponding to these meta tokens are extracted and become the multi-vector retrieval embedding. Each item is represented by E(x) = {e_1, …, e_M}, with each e_i a d-dimensional vector. These meta-token vectors summarize the document while remaining much smaller than full token-level embeddings.
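
A minimal sketch of this mechanism (hypothetical module and parameter names, not the released MetaEmbed code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaTokenEncoder(nn.Module):
    """Appends M learnable meta tokens and keeps only their hidden states as the embedding."""

    def __init__(self, backbone, hidden_dim, num_meta_tokens=16):
        super().__init__()
        self.backbone = backbone                       # any encoder mapping (B, L, H) -> (B, L, H)
        self.meta_tokens = nn.Parameter(torch.randn(num_meta_tokens, hidden_dim) * 0.02)

    def forward(self, token_embeddings):               # (B, L, H) patch/text token embeddings
        B = token_embeddings.size(0)
        meta = self.meta_tokens.unsqueeze(0).expand(B, -1, -1)     # (B, M, H)
        x = torch.cat([token_embeddings, meta], dim=1)             # append meta tokens to the sequence
        hidden = self.backbone(x)                                  # transformer forward pass
        meta_hidden = hidden[:, -meta.size(1):, :]                 # extract only the meta-token states
        return F.normalize(meta_hidden, dim=-1)                    # (B, M, H) multi-vector embedding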

Late Interaction and Efficiency

MetaEmbed uses MaxSim-style late interaction: Score(Q, D) = Σ_i max_j q_i^T d_j. This preserves fine-grained alignment, but only across the compact meta-token set. Scoring complexity becomes O(|Q| × M) instead of O(|Q| × |patch tokens|), which is the main efficiency gain.

Nested (Matryoshka) Multi-Vector Design

Meta-token vectors are trained in a nested fashion: the first vector captures the most important information, the first two capture more, and so on up to the full M. At test time, you can truncate to E_k(x) = {e_1, …, e_k} for k <= M without retraining. This enables test-time retrieval budget control.
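
Under the same assumptions, test-time budget control reduces to slicing the leading k meta vectors before MaxSim scoring (illustrative sketch; inputs are torch tensors):

def maxsim_score(query_vecs, doc_vecs):
    # query_vecs: (Nq, D), doc_vecs: (M, D) — sum over query tokens of the best-matching doc vector
    return (query_vecs @ doc_vecs.T).max(dim=1).values.sum()

def score_with_budget(query_vecs, doc_meta_vecs, k):
    # Matryoshka-style truncation: the first k meta vectors are trained to stand on their own.
    return maxsim_score(query_vecs, doc_meta_vecs[:k])

# e.g., k=1 for a cheap first-pass ranking, k=M for the full-quality score — no retraining needed.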

Comparison

  • vs ColPali: ColPali uses full token/patch embeddings (many vectors); MetaEmbed uses a small fixed number of learned meta vectors.
  • vs SV-RAG: SV-RAG learns retrieval inside the MLLM via adapters; MetaEmbed is a standalone external retriever.
  • vs URaG: URaG prunes pages inside the forward pass; MetaEmbed scales retrieval cost before generation.

Method Selection by Deployment Regime

The choice of method depends on system constraints:

  • When the objective is fine-grained page-level retrieval accuracy, ColPali is particularly suitable due to its late-interaction multi-vector design.
  • When the goal is to endow an MLLM with internal retrieval capability without external indexing, SV-RAG provides a principled adapter-based solution.
  • When long-document inference is computationally prohibitive, URaG’s in-model evidence localization and pruning strategy offers efficiency gains.
  • When deployment constraints demand explicit control over retrieval cost at inference time, MetaEmbed supports adjustable multi-vector representations.

Thank you for reading!

 
