use futures::channel::oneshot;
use sc_client_api::Backend;
use sp_runtime::traits::Block as BlockT;
use std::{
collections::{hash_map::Entry, HashMap},
sync::Arc,
time::{Duration, Instant},
};
use crate::chain_head::subscription::SubscriptionManagementError;
/// Per-subscription lifecycle of a single tracked block.
///
/// A block may be registered with a subscription twice and unpinned once; the
/// unpin may arrive before or after the second registration:
///
/// - `Registered` --register--> `FullyRegistered` --unpin--> `FullyUnpinned`
/// - `Registered` --unpin--> `Unpinned` --register--> `FullyUnpinned`
///
/// `FullyUnpinned` is terminal: further events leave it unchanged.
#[derive(Debug, Clone, PartialEq)]
enum BlockStateMachine {
	/// Registered once, not unpinned yet.
	Registered,
	/// Registered a second time, not unpinned yet.
	FullyRegistered,
	/// Unpinned before the second registration arrived.
	Unpinned,
	/// Every expected event has been observed.
	FullyUnpinned,
}

impl BlockStateMachine {
	/// State of a block right after its first registration.
	fn new() -> Self {
		BlockStateMachine::Registered
	}

	/// Apply a subsequent registration of the block.
	///
	/// `Registered` becomes `FullyRegistered` and `Unpinned` becomes `FullyUnpinned`;
	/// the other states are unchanged.
	fn advance_register(&mut self) {
		// Exhaustive on purpose: adding a variant must force a review of the transitions.
		*self = match *self {
			BlockStateMachine::Registered => BlockStateMachine::FullyRegistered,
			BlockStateMachine::Unpinned => BlockStateMachine::FullyUnpinned,
			BlockStateMachine::FullyRegistered => BlockStateMachine::FullyRegistered,
			BlockStateMachine::FullyUnpinned => BlockStateMachine::FullyUnpinned,
		};
	}

	/// Apply an unpin of the block.
	///
	/// `Registered` becomes `Unpinned` and `FullyRegistered` becomes `FullyUnpinned`;
	/// the other states (already unpinned) are unchanged.
	fn advance_unpin(&mut self) {
		*self = match *self {
			BlockStateMachine::Registered => BlockStateMachine::Unpinned,
			BlockStateMachine::FullyRegistered => BlockStateMachine::FullyUnpinned,
			BlockStateMachine::Unpinned => BlockStateMachine::Unpinned,
			BlockStateMachine::FullyUnpinned => BlockStateMachine::FullyUnpinned,
		};
	}

	/// Whether the block has already been unpinned (`Unpinned` or `FullyUnpinned`).
	fn was_unpinned(&self) -> bool {
		matches!(self, BlockStateMachine::Unpinned | BlockStateMachine::FullyUnpinned)
	}
}
/// Bookkeeping for one block tracked by a subscription.
struct BlockState {
	/// Moment the entry was created (the block's first registration).
	timestamp: Instant,
	/// Position of the block in its register/unpin lifecycle.
	state_machine: BlockStateMachine,
}
/// State kept for a single subscription.
struct SubscriptionState<Block: BlockT> {
	/// Blocks this subscription has registered, keyed by block hash.
	blocks: HashMap<Block::Hash, BlockState>,
	/// Flag supplied when the subscription was inserted; forwarded to every
	/// `BlockGuard` handed out for this subscription.
	with_runtime: bool,
	/// Sending half of the stop channel; taken (and fired) by the first `stop` call.
	tx_stop: Option<oneshot::Sender<()>>,
}
impl<Block: BlockT> SubscriptionState<Block> {
	/// Fire the stop channel. Only the first call sends; later calls are no-ops.
	fn stop(&mut self) {
		let Some(sender) = self.tx_stop.take() else { return };
		// A send error only means the receiver is gone; there is nobody left to notify.
		let _ = sender.send(());
	}

	/// Record a registration of `hash` for this subscription.
	///
	/// Returns `true` only if the block was not tracked before (a fresh entry in the
	/// `Registered` state is created). Otherwise the existing entry's state machine is
	/// advanced, the entry is dropped if it reached `FullyUnpinned`, and `false` is
	/// returned.
	fn register_block(&mut self, hash: Block::Hash) -> bool {
		match self.blocks.entry(hash) {
			Entry::Vacant(slot) => {
				slot.insert(BlockState {
					state_machine: BlockStateMachine::new(),
					timestamp: Instant::now(),
				});
				true
			},
			Entry::Occupied(mut slot) => {
				let machine = &mut slot.get_mut().state_machine;
				machine.advance_register();
				let lifecycle_done = *machine == BlockStateMachine::FullyUnpinned;
				if lifecycle_done {
					slot.remove();
				}
				false
			},
		}
	}

	/// Record an unpin of `hash` for this subscription.
	///
	/// Returns `false` if the block is not tracked or was already unpinned; otherwise
	/// advances the state machine (dropping the entry once it reaches `FullyUnpinned`)
	/// and returns `true`.
	fn unregister_block(&mut self, hash: Block::Hash) -> bool {
		let Entry::Occupied(mut slot) = self.blocks.entry(hash) else { return false };
		let machine = &mut slot.get_mut().state_machine;
		// A second unpin of the same block is rejected.
		if machine.was_unpinned() {
			return false
		}
		machine.advance_unpin();
		let lifecycle_done = *machine == BlockStateMachine::FullyUnpinned;
		if lifecycle_done {
			slot.remove();
		}
		true
	}

	/// Whether `hash` is tracked by this subscription and has not been unpinned.
	fn contains_block(&self, hash: Block::Hash) -> bool {
		self.blocks
			.get(&hash)
			.map_or(false, |state| !state.state_machine.was_unpinned())
	}

	/// Earliest `timestamp` among the tracked entries (including `Unpinned` ones still
	/// in the map), or the current time when nothing is tracked.
	fn find_oldest_block_timestamp(&self) -> Instant {
		self.blocks
			.values()
			.map(|state| state.timestamp)
			.fold(Instant::now(), std::cmp::min)
	}
}
/// RAII handle that keeps a block pinned in the backend.
///
/// The block is pinned when the guard is created and unpinned when the guard is dropped.
pub struct BlockGuard<Block: BlockT, BE: Backend<Block>> {
	/// Backend the pin was taken on and will be released on.
	backend: Arc<BE>,
	/// Hash of the pinned block.
	hash: Block::Hash,
	/// Flag inherited from the subscription that produced this guard.
	with_runtime: bool,
}
impl<Block: BlockT, BE: Backend<Block>> std::fmt::Debug for BlockGuard<Block, BE> {
	// Only the hash and the runtime flag are shown; the backend handle is omitted.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let Self { hash, with_runtime, .. } = self;
		write!(f, "BlockGuard hash {:?} with_runtime {:?}", hash, with_runtime)
	}
}
impl<Block: BlockT, BE: Backend<Block>> BlockGuard<Block, BE> {
	/// Pin `hash` in `backend` and return a guard that releases the pin on drop.
	///
	/// # Errors
	///
	/// `SubscriptionManagementError::Custom` carrying the backend error's message if the
	/// block cannot be pinned; no guard is created in that case.
	fn new(
		hash: Block::Hash,
		with_runtime: bool,
		backend: Arc<BE>,
	) -> Result<Self, SubscriptionManagementError> {
		match backend.pin_block(hash) {
			Ok(_) => Ok(Self { hash, with_runtime, backend }),
			Err(err) => Err(SubscriptionManagementError::Custom(err.to_string())),
		}
	}

	/// The `with_runtime` flag this guard was created with.
	pub fn has_runtime(&self) -> bool {
		let flag = self.with_runtime;
		flag
	}
}
impl<Block: BlockT, BE: Backend<Block>> Drop for BlockGuard<Block, BE> {
	// Release the pin taken in `BlockGuard::new`.
	fn drop(&mut self) {
		let pinned = self.hash;
		self.backend.unpin_block(pinned);
	}
}
/// Registry of active subscriptions and of the blocks they keep pinned in the backend.
pub struct SubscriptionsInner<Block: BlockT, BE: Backend<Block>> {
	/// Reference count per globally pinned hash: how many subscriptions currently hold it.
	/// The backend pin is taken when a hash enters this map and released when it leaves.
	global_blocks: HashMap<Block::Hash, usize>,
	/// Number of distinct hashes allowed in `global_blocks` before subscriptions are
	/// terminated to make room.
	global_max_pinned_blocks: usize,
	/// When space must be freed, subscriptions whose oldest block has been held longer
	/// than this are terminated first.
	local_max_pin_duration: Duration,
	/// Active subscriptions keyed by subscription ID.
	subs: HashMap<String, SubscriptionState<Block>>,
	/// Backend used to pin and unpin blocks.
	backend: Arc<BE>,
}
impl<Block: BlockT, BE: Backend<Block>> SubscriptionsInner<Block, BE> {
	/// Construct an empty registry.
	///
	/// - `global_max_pinned_blocks`: number of distinct hashes that may be pinned across all
	///   subscriptions before `pin_block` starts terminating subscriptions.
	/// - `local_max_pin_duration`: once that limit is reached, subscriptions whose oldest
	///   block has been held longer than this are terminated first.
	/// - `backend`: used to pin and unpin blocks.
	pub fn new(
		global_max_pinned_blocks: usize,
		local_max_pin_duration: Duration,
		backend: Arc<BE>,
	) -> Self {
		SubscriptionsInner {
			global_blocks: Default::default(),
			global_max_pinned_blocks,
			local_max_pin_duration,
			subs: Default::default(),
			backend,
		}
	}

	/// Register a new subscription with ID `sub_id`.
	///
	/// Returns the receiving half of the subscription's stop channel (signalled when the
	/// subscription is removed), or `None` if `sub_id` is already in use; in that case
	/// the existing subscription is left untouched.
	pub fn insert_subscription(
		&mut self,
		sub_id: String,
		with_runtime: bool,
	) -> Option<oneshot::Receiver<()>> {
		if let Entry::Vacant(entry) = self.subs.entry(sub_id) {
			let (tx_stop, rx_stop) = oneshot::channel();
			let state = SubscriptionState::<Block> {
				with_runtime,
				tx_stop: Some(tx_stop),
				blocks: Default::default(),
			};
			entry.insert(state);
			Some(rx_stop)
		} else {
			None
		}
	}

	/// Remove subscription `sub_id`, signal its stop channel and drop its share of every
	/// block it still holds (blocks it already unpinned were released at unpin time).
	/// Unknown IDs are ignored.
	pub fn remove_subscription(&mut self, sub_id: &str) {
		let Some(mut sub) = self.subs.remove(sub_id) else { return };
		sub.stop();
		for (hash, state) in sub.blocks.iter() {
			if !state.state_machine.was_unpinned() {
				self.global_unregister_block(*hash);
			}
		}
	}

	/// Make room for one more globally pinned hash, terminating subscriptions if needed.
	///
	/// Nothing happens while fewer than `global_max_pinned_blocks` hashes are pinned.
	/// Otherwise every subscription whose oldest block has been held for longer than
	/// `local_max_pin_duration` is terminated; if that still does not free space, all
	/// remaining subscriptions are terminated.
	///
	/// Returns `true` if `request_sub_id` was among the terminated subscriptions.
	fn ensure_block_space(&mut self, request_sub_id: &str) -> bool {
		if self.global_blocks.len() < self.global_max_pinned_blocks {
			return false
		}

		let now = Instant::now();
		let max_duration = self.local_max_pin_duration;
		let expired: Vec<String> = self
			.subs
			.iter()
			.filter(|(_, sub)| {
				// A subscription with no blocks reports a timestamp taken after `now`.
				// Saturating to zero keeps it from being treated as expired; it holds
				// nothing, so terminating it would not free any space.
				now.saturating_duration_since(sub.find_oldest_block_timestamp()) > max_duration
			})
			.map(|(sub_id, _)| sub_id.clone())
			.collect();
		let mut is_terminated = self.terminate_subscriptions(expired, request_sub_id);

		if self.global_blocks.len() < self.global_max_pinned_blocks {
			return is_terminated
		}

		// Removing expired subscriptions was not enough: terminate everything.
		let everyone: Vec<String> = self.subs.keys().cloned().collect();
		// `|=` (not `||`) so the removal always runs.
		is_terminated |= self.terminate_subscriptions(everyone, request_sub_id);
		is_terminated
	}

	/// Remove every subscription in `sub_ids`; report whether `request_sub_id` was one of them.
	fn terminate_subscriptions(&mut self, sub_ids: Vec<String>, request_sub_id: &str) -> bool {
		let mut is_terminated = false;
		for sub_id in sub_ids {
			is_terminated |= sub_id == request_sub_id;
			self.remove_subscription(&sub_id);
		}
		is_terminated
	}

	/// Pin `hash` on behalf of subscription `sub_id`.
	///
	/// Returns `Ok(true)` when the block is newly registered for this subscription and
	/// `Ok(false)` when the subscription already tracked it; in the latter case the global
	/// reference count is not touched.
	///
	/// # Errors
	///
	/// - `SubscriptionAbsent` if `sub_id` is unknown.
	/// - `ExceededLimits` if making room for the block terminated `sub_id` itself.
	/// - `Custom` if the backend fails to pin the block.
	pub fn pin_block(
		&mut self,
		sub_id: &str,
		hash: Block::Hash,
	) -> Result<bool, SubscriptionManagementError> {
		let Some(sub) = self.subs.get_mut(sub_id) else {
			return Err(SubscriptionManagementError::SubscriptionAbsent)
		};
		if !sub.register_block(hash) {
			return Ok(false)
		}
		// Space is only needed when the hash is not already pinned by someone else.
		if !self.global_blocks.contains_key(&hash) {
			if self.ensure_block_space(sub_id) {
				return Err(SubscriptionManagementError::ExceededLimits)
			}
		}
		self.global_register_block(hash)?;
		Ok(true)
	}

	/// Increment the global reference count of `hash`, pinning it in the backend on the
	/// first reference.
	///
	/// # Errors
	///
	/// `Custom` if the backend fails to pin; the hash is then not recorded.
	fn global_register_block(
		&mut self,
		hash: Block::Hash,
	) -> Result<(), SubscriptionManagementError> {
		match self.global_blocks.entry(hash) {
			Entry::Occupied(mut occupied) => {
				*occupied.get_mut() += 1;
			},
			Entry::Vacant(vacant) => {
				self.backend
					.pin_block(hash)
					.map_err(|err| SubscriptionManagementError::Custom(err.to_string()))?;
				vacant.insert(1);
			},
		};
		Ok(())
	}

	/// Decrement the global reference count of `hash`, unpinning it in the backend when the
	/// last reference goes away. Hashes that are not tracked are ignored.
	fn global_unregister_block(&mut self, hash: Block::Hash) {
		if let Entry::Occupied(mut occupied) = self.global_blocks.entry(hash) {
			let counter = occupied.get_mut();
			if *counter == 1 {
				self.backend.unpin_block(hash);
				occupied.remove();
			} else {
				*counter -= 1;
			}
		}
	}

	/// Unpin `hash` on behalf of subscription `sub_id`.
	///
	/// # Errors
	///
	/// - `SubscriptionAbsent` if `sub_id` is unknown.
	/// - `BlockHashAbsent` if the subscription does not track `hash` or already unpinned it.
	pub fn unpin_block(
		&mut self,
		sub_id: &str,
		hash: Block::Hash,
	) -> Result<(), SubscriptionManagementError> {
		let Some(sub) = self.subs.get_mut(sub_id) else {
			return Err(SubscriptionManagementError::SubscriptionAbsent)
		};
		if !sub.unregister_block(hash) {
			return Err(SubscriptionManagementError::BlockHashAbsent)
		}
		self.global_unregister_block(hash);
		Ok(())
	}

	/// Obtain a `BlockGuard` for `hash`, which takes its own backend pin for as long as the
	/// guard lives.
	///
	/// # Errors
	///
	/// - `SubscriptionAbsent` if `sub_id` is unknown.
	/// - `BlockHashAbsent` if the subscription does not track `hash` or already unpinned it.
	/// - `Custom` if the backend fails to pin the block for the guard.
	pub fn lock_block(
		&mut self,
		sub_id: &str,
		hash: Block::Hash,
	) -> Result<BlockGuard<Block, BE>, SubscriptionManagementError> {
		let Some(sub) = self.subs.get(sub_id) else {
			return Err(SubscriptionManagementError::SubscriptionAbsent)
		};
		if !sub.contains_block(hash) {
			return Err(SubscriptionManagementError::BlockHashAbsent)
		}
		BlockGuard::new(hash, sub.with_runtime, self.backend.clone())
	}
}
#[cfg(test)]
mod tests {
	use super::*;
	use sc_block_builder::BlockBuilderProvider;
	use sc_service::client::new_in_mem;
	use sp_consensus::BlockOrigin;
	use sp_core::{testing::TaskExecutor, H256};
	use substrate_test_runtime_client::{
		prelude::*,
		runtime::{Block, RuntimeApi},
		Client, ClientBlockImportExt, GenesisInit,
	};

	/// Build an in-memory backend and a client on top of it, initialised from the test
	/// runtime's default genesis storage. Tests import real blocks through the client so
	/// that the backend can pin them.
	fn init_backend() -> (
		Arc<sc_client_api::in_mem::Backend<Block>>,
		Arc<Client<sc_client_api::in_mem::Backend<Block>>>,
	) {
		let backend = Arc::new(sc_client_api::in_mem::Backend::new());
		let executor = substrate_test_runtime_client::new_native_or_wasm_executor();
		let client_config = sc_service::ClientConfig::default();
		let genesis_block_builder = sc_service::GenesisBlockBuilder::new(
			&substrate_test_runtime_client::GenesisParameters::default().genesis_storage(),
			!client_config.no_genesis,
			backend.clone(),
			executor.clone(),
		)
		.unwrap();
		let client = Arc::new(
			new_in_mem::<_, Block, _, RuntimeApi>(
				backend.clone(),
				executor,
				genesis_block_builder,
				None,
				None,
				Box::new(TaskExecutor::new()),
				client_config,
			)
			.unwrap(),
		);
		(backend, client)
	}

	// Path: register, register again, then unpin; FullyUnpinned must be terminal.
	#[test]
	fn block_state_machine_register_unpin() {
		let mut state = BlockStateMachine::new();
		assert_eq!(state, BlockStateMachine::Registered);
		state.advance_register();
		assert_eq!(state, BlockStateMachine::FullyRegistered);
		// A third registration is a no-op.
		state.advance_register();
		assert_eq!(state, BlockStateMachine::FullyRegistered);
		assert!(!state.was_unpinned());
		state.advance_unpin();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
		assert!(state.was_unpinned());
		state.advance_unpin();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
		assert!(state.was_unpinned());
		state.advance_register();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
	}

	// Path: unpin before the second registration; repeated events do not move the state.
	#[test]
	fn block_state_machine_unpin_register() {
		let mut state = BlockStateMachine::new();
		assert_eq!(state, BlockStateMachine::Registered);
		assert!(!state.was_unpinned());
		state.advance_unpin();
		assert_eq!(state, BlockStateMachine::Unpinned);
		assert!(state.was_unpinned());
		state.advance_unpin();
		assert_eq!(state, BlockStateMachine::Unpinned);
		assert!(state.was_unpinned());
		state.advance_register();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
		assert!(state.was_unpinned());
		state.advance_register();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
		state.advance_unpin();
		assert_eq!(state, BlockStateMachine::FullyUnpinned);
		assert!(state.was_unpinned());
	}

	// Only the first registration reports `true`; reaching FullyUnpinned drops the entry.
	#[test]
	fn sub_state_register_twice() {
		let mut sub_state = SubscriptionState::<Block> {
			with_runtime: false,
			tx_stop: None,
			blocks: Default::default(),
		};
		let hash = H256::random();
		assert_eq!(sub_state.register_block(hash), true);
		let block_state = sub_state.blocks.get(&hash).unwrap();
		assert_eq!(block_state.state_machine, BlockStateMachine::Registered);
		assert_eq!(sub_state.register_block(hash), false);
		let block_state = sub_state.blocks.get(&hash).unwrap();
		assert_eq!(block_state.state_machine, BlockStateMachine::FullyRegistered);
		assert_eq!(sub_state.unregister_block(hash), true);
		let block_state = sub_state.blocks.get(&hash);
		assert!(block_state.is_none());
	}

	// Unpinning an unknown block fails; unpin-then-register completes the lifecycle and
	// removes the entry, after which further unpins fail.
	#[test]
	fn sub_state_register_unregister() {
		let mut sub_state = SubscriptionState::<Block> {
			with_runtime: false,
			tx_stop: None,
			blocks: Default::default(),
		};
		let hash = H256::random();
		assert_eq!(sub_state.unregister_block(hash), false);
		assert_eq!(sub_state.register_block(hash), true);
		let block_state = sub_state.blocks.get(&hash).unwrap();
		assert_eq!(block_state.state_machine, BlockStateMachine::Registered);
		assert_eq!(sub_state.unregister_block(hash), true);
		let block_state = sub_state.blocks.get(&hash).unwrap();
		assert_eq!(block_state.state_machine, BlockStateMachine::Unpinned);
		assert_eq!(sub_state.register_block(hash), false);
		let block_state = sub_state.blocks.get(&hash);
		assert!(block_state.is_none());
		assert_eq!(sub_state.unregister_block(hash), false);
		let block_state = sub_state.blocks.get(&hash);
		assert!(block_state.is_none());
	}

	// `lock_block` error reporting: unknown subscription vs. unknown block; duplicate IDs
	// are rejected by `insert_subscription`.
	#[test]
	fn subscription_lock_block() {
		let builder = TestClientBuilder::new();
		let backend = builder.backend();
		let mut subs = SubscriptionsInner::new(10, Duration::from_secs(10), backend);
		let id = "abc".to_string();
		let hash = H256::random();
		let err = subs.lock_block(&id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
		let _stop = subs.insert_subscription(id.clone(), true).unwrap();
		assert!(subs.insert_subscription(id.clone(), true).is_none());
		let err = subs.lock_block(&id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::BlockHashAbsent);
		subs.remove_subscription(&id);
		let err = subs.lock_block(&id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
	}

	// Pin, lock (guard carries the runtime flag), then unpin; a locked-then-unpinned
	// block can no longer be locked.
	#[test]
	fn subscription_check_block() {
		let (backend, mut client) = init_backend();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let mut subs = SubscriptionsInner::new(10, Duration::from_secs(10), backend);
		let id = "abc".to_string();
		let _stop = subs.insert_subscription(id.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id, hash).unwrap(), true);
		let block = subs.lock_block(&id, hash).unwrap();
		assert_eq!(block.has_runtime(), true);
		let invalid_id = "abc-invalid".to_string();
		let err = subs.unpin_block(&invalid_id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
		subs.unpin_block(&id, hash).unwrap();
		let err = subs.lock_block(&id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::BlockHashAbsent);
	}

	// Global reference counting: a repeated pin by the same subscription does not bump the
	// count; a second subscription does; the entry disappears after the last unpin.
	#[test]
	fn subscription_ref_count() {
		let (backend, mut client) = init_backend();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let mut subs = SubscriptionsInner::new(10, Duration::from_secs(10), backend);
		let id = "abc".to_string();
		let _stop = subs.insert_subscription(id.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id, hash).unwrap(), true);
		assert_eq!(*subs.global_blocks.get(&hash).unwrap(), 1);
		subs.subs.get(&id).unwrap().blocks.get(&hash).unwrap();
		assert_eq!(subs.pin_block(&id, hash).unwrap(), false);
		assert_eq!(*subs.global_blocks.get(&hash).unwrap(), 1);
		let id_second = "abcd".to_string();
		let _stop = subs.insert_subscription(id_second.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_second, hash).unwrap(), true);
		assert_eq!(*subs.global_blocks.get(&hash).unwrap(), 2);
		subs.subs.get(&id_second).unwrap().blocks.get(&hash).unwrap();
		subs.unpin_block(&id, hash).unwrap();
		assert_eq!(*subs.global_blocks.get(&hash).unwrap(), 1);
		let err = subs.unpin_block(&id, hash).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::BlockHashAbsent);
		subs.unpin_block(&id_second, hash).unwrap();
		assert!(subs.global_blocks.get(&hash).is_none());
	}

	// Removing a subscription releases exactly its own share of each pinned block.
	#[test]
	fn subscription_remove_subscription() {
		let (backend, mut client) = init_backend();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_1 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_2 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_3 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let mut subs = SubscriptionsInner::new(10, Duration::from_secs(10), backend);
		let id_1 = "abc".to_string();
		let id_2 = "abcd".to_string();
		let _stop = subs.insert_subscription(id_1.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_1, hash_1).unwrap(), true);
		assert_eq!(subs.pin_block(&id_1, hash_2).unwrap(), true);
		assert_eq!(subs.pin_block(&id_1, hash_3).unwrap(), true);
		let _stop = subs.insert_subscription(id_2.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_2, hash_2).unwrap(), true);
		assert_eq!(*subs.global_blocks.get(&hash_1).unwrap(), 1);
		assert_eq!(*subs.global_blocks.get(&hash_2).unwrap(), 2);
		assert_eq!(*subs.global_blocks.get(&hash_3).unwrap(), 1);
		subs.remove_subscription(&id_1);
		assert!(subs.global_blocks.get(&hash_1).is_none());
		assert_eq!(*subs.global_blocks.get(&hash_2).unwrap(), 1);
		assert!(subs.global_blocks.get(&hash_3).is_none());
		subs.remove_subscription(&id_2);
		assert!(subs.global_blocks.get(&hash_2).is_none());
		assert_eq!(subs.global_blocks.len(), 0);
	}

	// With a global limit of 2 and no expired subscriptions, pinning a third hash
	// terminates every subscription and releases all pins.
	#[test]
	fn subscription_check_limits() {
		let (backend, mut client) = init_backend();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_1 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_2 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_3 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let mut subs = SubscriptionsInner::new(2, Duration::from_secs(10), backend);
		let id_1 = "abc".to_string();
		let id_2 = "abcd".to_string();
		let _stop = subs.insert_subscription(id_1.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_1, hash_1).unwrap(), true);
		assert_eq!(subs.pin_block(&id_1, hash_2).unwrap(), true);
		let _stop = subs.insert_subscription(id_2.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_2, hash_1).unwrap(), true);
		assert_eq!(subs.pin_block(&id_2, hash_2).unwrap(), true);
		assert_eq!(*subs.global_blocks.get(&hash_1).unwrap(), 2);
		assert_eq!(*subs.global_blocks.get(&hash_2).unwrap(), 2);
		let err = subs.pin_block(&id_1, hash_3).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::ExceededLimits);
		let err = subs.lock_block(&id_1, hash_1).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
		let err = subs.lock_block(&id_2, hash_1).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
		assert!(subs.global_blocks.get(&hash_1).is_none());
		assert!(subs.global_blocks.get(&hash_2).is_none());
		assert!(subs.global_blocks.get(&hash_3).is_none());
		assert_eq!(subs.global_blocks.len(), 0);
	}

	// The subscription that held blocks past `local_max_pin_duration` is terminated first
	// while the fresher one survives. Once only fresh subscriptions remain and the limit is
	// hit again, all are terminated. NOTE: sleeps for 5 seconds.
	#[test]
	fn subscription_check_limits_with_duration() {
		let (backend, mut client) = init_backend();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_1 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_2 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let block = client.new_block(Default::default()).unwrap().build().unwrap().block;
		let hash_3 = block.header.hash();
		futures::executor::block_on(client.import(BlockOrigin::Own, block.clone())).unwrap();
		let mut subs = SubscriptionsInner::new(2, Duration::from_secs(5), backend);
		let id_1 = "abc".to_string();
		let id_2 = "abcd".to_string();
		let _stop = subs.insert_subscription(id_1.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_1, hash_1).unwrap(), true);
		assert_eq!(subs.pin_block(&id_1, hash_2).unwrap(), true);
		// Age id_1's blocks past the 5s local limit.
		std::thread::sleep(std::time::Duration::from_secs(5));
		let _stop = subs.insert_subscription(id_2.clone(), true).unwrap();
		assert_eq!(subs.pin_block(&id_2, hash_1).unwrap(), true);
		assert_eq!(*subs.global_blocks.get(&hash_1).unwrap(), 2);
		assert_eq!(*subs.global_blocks.get(&hash_2).unwrap(), 1);
		let err = subs.pin_block(&id_1, hash_3).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::ExceededLimits);
		let err = subs.lock_block(&id_1, hash_1).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::SubscriptionAbsent);
		let _block_guard = subs.lock_block(&id_2, hash_1).unwrap();
		assert_eq!(*subs.global_blocks.get(&hash_1).unwrap(), 1);
		assert!(subs.global_blocks.get(&hash_2).is_none());
		assert!(subs.global_blocks.get(&hash_3).is_none());
		assert_eq!(subs.global_blocks.len(), 1);
		assert_eq!(subs.pin_block(&id_2, hash_2).unwrap(), true);
		let err = subs.pin_block(&id_2, hash_3).unwrap_err();
		assert_eq!(err, SubscriptionManagementError::ExceededLimits);
		assert!(subs.global_blocks.get(&hash_1).is_none());
		assert!(subs.global_blocks.get(&hash_2).is_none());
		assert!(subs.global_blocks.get(&hash_3).is_none());
		assert_eq!(subs.global_blocks.len(), 0);
	}

	// The stop receiver stays pending until `stop` is called, then yields a value.
	#[test]
	fn subscription_check_stop_event() {
		let builder = TestClientBuilder::new();
		let backend = builder.backend();
		let mut subs = SubscriptionsInner::new(10, Duration::from_secs(10), backend);
		let id = "abc".to_string();
		let mut rx_stop = subs.insert_subscription(id.clone(), true).unwrap();
		let res = rx_stop.try_recv().unwrap();
		assert!(res.is_none());
		let sub = subs.subs.get_mut(&id).unwrap();
		sub.stop();
		let res = rx_stop.try_recv().unwrap();
		assert!(res.is_some());
	}
}