Skip to content

Commit a39fe5d

Browse files
feat: batch cold account loads in light client
1 parent b07fd20 commit a39fe5d

3 files changed

Lines changed: 126 additions & 49 deletions

File tree

sdk-libs/client/src/indexer/photon_indexer.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,17 +1142,16 @@ impl Indexer for PhotonIndexer {
11421142
.value
11431143
.iter()
11441144
.map(|x| {
1145-
let mut proof_vec = x.proof.clone();
1146-
if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH {
1145+
if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH {
11471146
return Err(IndexerError::InvalidParameters(format!(
11481147
"Merkle proof length ({}) is less than canopy depth ({})",
1149-
proof_vec.len(),
1148+
x.proof.len(),
11501149
STATE_MERKLE_TREE_CANOPY_DEPTH,
11511150
)));
11521151
}
1153-
proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH);
1152+
let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH;
11541153

1155-
let proof = proof_vec
1154+
let proof = x.proof[..proof_len]
11561155
.iter()
11571156
.map(|s| Hash::from_base58(s))
11581157
.collect::<Result<Vec<[u8; 32]>, IndexerError>>()
@@ -1703,15 +1702,13 @@ impl Indexer for PhotonIndexer {
17031702

17041703
async fn get_subtrees(
17051704
&self,
1706-
_merkle_tree_pubkey: [u8; 32],
1705+
merkle_tree_pubkey: [u8; 32],
17071706
_config: Option<IndexerRpcConfig>,
17081707
) -> Result<Response<Items<[u8; 32]>>, IndexerError> {
1709-
#[cfg(not(feature = "v2"))]
1710-
unimplemented!();
1711-
#[cfg(feature = "v2")]
1712-
{
1713-
todo!();
1714-
}
1708+
Err(IndexerError::NotImplemented(format!(
1709+
"PhotonIndexer::get_subtrees is not implemented for merkle tree {}",
1710+
solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey)
1711+
)))
17151712
}
17161713
}
17171714

sdk-libs/client/src/indexer/types/queue.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,14 @@ pub struct AddressQueueData {
6666
}
6767

6868
impl AddressQueueData {
69+
const ADDRESS_TREE_HEIGHT: usize = 40;
70+
6971
/// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes.
7072
pub fn reconstruct_proof<const HEIGHT: usize>(
7173
&self,
7274
address_idx: usize,
7375
) -> Result<[[u8; 32]; HEIGHT], IndexerError> {
76+
self.validate_proof_height::<HEIGHT>()?;
7477
let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| {
7578
IndexerError::MissingResult {
7679
context: "reconstruct_proof".to_string(),
@@ -126,6 +129,7 @@ impl AddressQueueData {
126129
&self,
127130
address_range: std::ops::Range<usize>,
128131
) -> Result<Vec<[[u8; 32]; HEIGHT]>, IndexerError> {
132+
self.validate_proof_height::<HEIGHT>()?;
129133
let node_lookup = self.build_node_lookup();
130134
let mut proofs = Vec::with_capacity(address_range.len());
131135

@@ -140,23 +144,24 @@ impl AddressQueueData {
140144
pub fn reconstruct_all_proofs<const HEIGHT: usize>(
141145
&self,
142146
) -> Result<Vec<[[u8; 32]; HEIGHT]>, IndexerError> {
147+
self.validate_proof_height::<HEIGHT>()?;
143148
self.reconstruct_proofs::<HEIGHT>(0..self.addresses.len())
144149
}
145150

146151
fn build_node_lookup(&self) -> HashMap<u64, usize> {
147-
self.nodes
148-
.iter()
149-
.copied()
150-
.enumerate()
151-
.map(|(idx, node)| (node, idx))
152-
.collect()
152+
let mut lookup = HashMap::with_capacity(self.nodes.len());
153+
for (idx, node) in self.nodes.iter().copied().enumerate() {
154+
lookup.entry(node).or_insert(idx);
155+
}
156+
lookup
153157
}
154158

155159
fn reconstruct_proof_with_lookup<const HEIGHT: usize>(
156160
&self,
157161
address_idx: usize,
158162
node_lookup: &HashMap<u64, usize>,
159163
) -> Result<[[u8; 32]; HEIGHT], IndexerError> {
164+
self.validate_proof_height::<HEIGHT>()?;
160165
let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| {
161166
IndexerError::MissingResult {
162167
context: "reconstruct_proof".to_string(),
@@ -209,6 +214,18 @@ impl AddressQueueData {
209214
fn encode_node_index(level: usize, position: u64) -> u64 {
210215
((level as u64) << 56) | position
211216
}
217+
218+
fn validate_proof_height<const HEIGHT: usize>(&self) -> Result<(), IndexerError> {
219+
if HEIGHT == Self::ADDRESS_TREE_HEIGHT {
220+
return Ok(());
221+
}
222+
223+
Err(IndexerError::InvalidParameters(format!(
224+
"address queue proofs require HEIGHT={} but got HEIGHT={}",
225+
Self::ADDRESS_TREE_HEIGHT,
226+
HEIGHT
227+
)))
228+
}
212229
}
213230

214231
#[cfg(test)]

sdk-libs/client/src/interface/load_accounts.rs

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ pub enum LoadAccountsError {
5353
#[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")]
5454
MissingPdaCompressed { index: usize, pubkey: Pubkey },
5555

56+
#[error("Cold PDA (pubkey {pubkey}) missing data")]
57+
MissingPdaCompressedData { pubkey: Pubkey },
58+
5659
#[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")]
5760
MissingAtaCompressed { index: usize, pubkey: Pubkey },
5861

@@ -67,6 +70,7 @@ pub enum LoadAccountsError {
6770
}
6871

6972
const MAX_ATAS_PER_IX: usize = 8;
73+
const MAX_PDAS_PER_IX: usize = 8;
7074

7175
/// Build load instructions for cold accounts. Returns empty vec if all hot.
7276
///
@@ -113,14 +117,18 @@ where
113117
})
114118
.collect();
115119

116-
let pda_hashes = collect_pda_hashes(&cold_pdas)?;
120+
let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX);
121+
let pda_hashes = pda_groups
122+
.iter()
123+
.map(|group| collect_pda_hashes(group))
124+
.collect::<Result<Vec<_>, _>>()?;
117125
let ata_hashes = collect_ata_hashes(&cold_atas)?;
118126
let mint_hashes = collect_mint_hashes(&cold_mints)?;
119127

120128
let (pda_proofs, ata_proofs, mint_proofs) = futures::join!(
121-
fetch_proofs(&pda_hashes, indexer),
129+
fetch_proof_batches(&pda_hashes, indexer),
122130
fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer),
123-
fetch_proofs(&mint_hashes, indexer),
131+
fetch_individual_proofs(&mint_hashes, indexer),
124132
);
125133

126134
let pda_proofs = pda_proofs?;
@@ -136,18 +144,17 @@ where
136144

137145
// 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs).
138146
// Token PDAs are created on-chain via CPI inside DecompressVariant.
139-
for (spec, proof) in cold_pdas.iter().zip(pda_proofs) {
147+
for (group, proof) in pda_groups.into_iter().zip(pda_proofs) {
140148
out.push(build_pda_load(
141-
&[spec],
149+
&group,
142150
proof,
143151
fee_payer,
144152
compression_config,
145153
)?);
146154
}
147155

148156
// 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist
149-
let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect();
150-
for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) {
157+
for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) {
151158
out.extend(build_ata_load(chunk, proof, fee_payer)?);
152159
}
153160

@@ -195,23 +202,77 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result<Vec<[u8; 32]>, Lo
195202
.collect()
196203
}
197204

198-
async fn fetch_proofs<I: Indexer>(
205+
/// Groups already-ordered PDA specs into contiguous runs of the same program id.
206+
///
207+
/// This preserves input order rather than globally regrouping by program. Callers that
208+
/// want maximal batching across interleaved program ids should sort before calling.
209+
fn group_pda_specs<'a, V>(
210+
specs: &[&'a PdaSpec<V>],
211+
max_per_group: usize,
212+
) -> Vec<Vec<&'a PdaSpec<V>>> {
213+
assert!(max_per_group > 0, "max_per_group must be non-zero");
214+
if specs.is_empty() {
215+
return Vec::new();
216+
}
217+
218+
let mut groups = Vec::new();
219+
let mut current = Vec::with_capacity(max_per_group);
220+
let mut current_program: Option<Pubkey> = None;
221+
222+
for spec in specs {
223+
let program_id = spec.program_id();
224+
let should_split = current_program
225+
.map(|existing| existing != program_id || current.len() >= max_per_group)
226+
.unwrap_or(false);
227+
228+
if should_split {
229+
groups.push(current);
230+
current = Vec::with_capacity(max_per_group);
231+
}
232+
233+
current_program = Some(program_id);
234+
current.push(*spec);
235+
}
236+
237+
if !current.is_empty() {
238+
groups.push(current);
239+
}
240+
241+
groups
242+
}
243+
244+
async fn fetch_individual_proofs<I: Indexer>(
199245
hashes: &[[u8; 32]],
200246
indexer: &I,
201247
) -> Result<Vec<ValidityProofWithContext>, IndexerError> {
202248
if hashes.is_empty() {
203249
return Ok(vec![]);
204250
}
205-
let mut proofs = Vec::with_capacity(hashes.len());
206-
for hash in hashes {
207-
proofs.push(
208-
indexer
209-
.get_validity_proof(vec![*hash], vec![], None)
210-
.await?
211-
.value,
212-
);
251+
252+
futures::future::try_join_all(hashes.iter().map(|hash| async move {
253+
indexer
254+
.get_validity_proof(vec![*hash], vec![], None)
255+
.await
256+
.map(|response| response.value)
257+
}))
258+
.await
259+
}
260+
261+
async fn fetch_proof_batches<I: Indexer>(
262+
hash_batches: &[Vec<[u8; 32]>],
263+
indexer: &I,
264+
) -> Result<Vec<ValidityProofWithContext>, IndexerError> {
265+
if hash_batches.is_empty() {
266+
return Ok(vec![]);
213267
}
214-
Ok(proofs)
268+
269+
futures::future::try_join_all(hash_batches.iter().map(|hashes| async move {
270+
indexer
271+
.get_validity_proof(hashes.clone(), vec![], None)
272+
.await
273+
.map(|response| response.value)
274+
}))
275+
.await
215276
}
216277

217278
async fn fetch_proofs_batched<I: Indexer>(
@@ -222,16 +283,13 @@ async fn fetch_proofs_batched<I: Indexer>(
222283
if hashes.is_empty() {
223284
return Ok(vec![]);
224285
}
225-
let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size));
226-
for chunk in hashes.chunks(batch_size) {
227-
proofs.push(
228-
indexer
229-
.get_validity_proof(chunk.to_vec(), vec![], None)
230-
.await?
231-
.value,
232-
);
233-
}
234-
Ok(proofs)
286+
287+
let hash_batches = hashes
288+
.chunks(batch_size)
289+
.map(|chunk| chunk.to_vec())
290+
.collect::<Vec<_>>();
291+
292+
fetch_proof_batches(&hash_batches, indexer).await
235293
}
236294

237295
fn build_pda_load<V>(
@@ -262,11 +320,16 @@ where
262320
let hot_addresses: Vec<Pubkey> = specs.iter().map(|s| s.address()).collect();
263321
let cold_accounts: Vec<(CompressedAccount, V)> = specs
264322
.iter()
265-
.map(|s| {
266-
let compressed = s.compressed().expect("cold spec must have data").clone();
267-
(compressed, s.variant.clone())
323+
.map(|s| -> Result<_, LoadAccountsError> {
324+
let compressed =
325+
s.compressed()
326+
.cloned()
327+
.ok_or(LoadAccountsError::MissingPdaCompressedData {
328+
pubkey: s.address(),
329+
})?;
330+
Ok((compressed, s.variant.clone()))
268331
})
269-
.collect();
332+
.collect::<Result<Vec<_>, _>>()?;
270333

271334
let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default();
272335

0 commit comments

Comments (0)