Skip to content

Commit 36a79b7

Browse files
committed
Implement write and read for the resource manager
Adds write and read implementations to persist the DefaultResourceManager.
1 parent 2c7aca7 commit 36a79b7

1 file changed

Lines changed: 271 additions & 4 deletions

File tree

lightning/src/ln/resource_manager.rs

Lines changed: 271 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99

1010
#![allow(dead_code)]
1111

12+
use bitcoin::io::Read;
1213
use core::{fmt::Display, time::Duration};
1314

1415
use crate::{
1516
crypto::chacha20::ChaCha20,
16-
ln::channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS,
17+
io,
18+
ln::{channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS, msgs::DecodeError},
1719
prelude::{hash_map::Entry, new_hash_map, HashMap},
1820
sign::EntropySource,
1921
sync::Mutex,
22+
util::ser::{CollectionLength, Readable, ReadableArgs, Writeable, Writer},
2023
};
2124

2225
/// A trait for managing channel resources and making HTLC forwarding decisions.
@@ -129,6 +132,14 @@ impl Default for ResourceManagerConfig {
129132
}
130133
}
131134

135+
// Persist the full configuration so a deserialized manager behaves identically
// to the one that was written out.
impl_writeable_tlv_based!(ResourceManagerConfig, {
	(1, general_allocation_pct, required),
	(3, congestion_allocation_pct, required),
	(5, resolution_period, required),
	(7, revenue_window, required),
	(9, reputation_multiplier, required),
});
132143
/// The outcome of an HTLC forwarding decision.
133144
#[derive(PartialEq, Eq, Debug)]
134145
pub enum ForwardingOutcome {
@@ -337,6 +348,47 @@ impl GeneralBucket {
337348
}
338349
}
339350

351+
impl Writeable for GeneralBucket {
352+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
353+
let channel_info: HashMap<u64, [u8; 32]> =
354+
self.channels_slots.iter().map(|(scid, (_slots, salt))| (*scid, *salt)).collect();
355+
356+
write_tlv_fields!(writer, {
357+
(1, self.scid, required),
358+
(3, self.total_slots, required),
359+
(5, self.total_liquidity, required),
360+
(7, channel_info, required),
361+
});
362+
Ok(())
363+
}
364+
}
365+
366+
impl<ES: EntropySource> ReadableArgs<&ES> for GeneralBucket {
	/// Reconstructs a [`GeneralBucket`] from its serialized form.
	///
	/// Only per-channel salts were persisted; each channel's slot assignment is
	/// rebuilt by replaying `assign_slots_for_channel` with its stored salt.
	fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, our_scid, required),
			(3, general_total_slots, required),
			(5, general_total_liquidity, required),
			(7, channel_info, required),
		});

		// `required` fields are guaranteed populated if the macro returned without
		// error, so these unwraps cannot panic.
		let mut general_bucket = GeneralBucket::new(
			our_scid.0.unwrap(),
			general_total_slots.0.unwrap(),
			general_total_liquidity.0.unwrap(),
		);

		// Re-derive every channel's slots from its persisted salt, failing the
		// whole deserialization if any assignment is rejected.
		let channel_info: HashMap<u64, [u8; 32]> = channel_info.0.unwrap();
		for (outgoing_scid, salt) in channel_info {
			general_bucket
				.assign_slots_for_channel(outgoing_scid, Some(salt), entropy_source)
				.map_err(|_| DecodeError::InvalidValue)?;
		}

		Ok(general_bucket)
	}
}
391+
340392
struct BucketResources {
341393
slots_allocated: u16,
342394
slots_used: u16,
@@ -374,7 +426,13 @@ impl BucketResources {
374426
}
375427
}
376428

377-
#[derive(Debug, Clone)]
429+
impl_writeable_tlv_based!(BucketResources, {
	(1, slots_allocated, required),
	// In-flight usage is intentionally not persisted: it is reset to zero on read
	// and re-established by replaying pending HTLCs after startup.
	(_unused, slots_used, (static_value, 0)),
	(3, liquidity_allocated, required),
	(_unused, liquidity_used, (static_value, 0)),
});
378436
struct PendingHTLC {
379437
incoming_amount_msat: u64,
380438
fee: u64,
@@ -565,6 +623,42 @@ impl Channel {
565623
}
566624
}
567625

626+
impl Writeable for Channel {
	/// Serializes the channel's reputation/revenue trackers and bucket state.
	///
	/// `pending_htlcs` is intentionally not written: in-flight HTLCs are replayed
	/// by the caller on startup (see `DefaultResourceManager::replay_pending_htlcs`).
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		write_tlv_fields!(writer, {
			(1, self.outgoing_reputation, required),
			(3, self.incoming_revenue, required),
			(5, self.general_bucket, required),
			(7, self.congestion_bucket, required),
			(9, self.last_congestion_misuse, required),
			(11, self.protected_bucket, required)
		});
		Ok(())
	}
}
639+
640+
impl<ES: EntropySource> ReadableArgs<&ES> for Channel {
	/// Reads a [`Channel`], re-deriving the general bucket's slot assignments via
	/// the provided entropy source.
	fn read<R: Read>(reader: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, outgoing_reputation, required),
			(3, incoming_revenue, required),
			// The general bucket needs the entropy source to rebuild slot
			// assignments from persisted salts.
			(5, general_bucket, (required: ReadableArgs, entropy_source)),
			(7, congestion_bucket, required),
			(9, last_congestion_misuse, required),
			(11, protected_bucket, required)
		});
		Ok(Channel {
			outgoing_reputation: outgoing_reputation.0.unwrap(),
			incoming_revenue: incoming_revenue.0.unwrap(),
			general_bucket: general_bucket.0.unwrap(),
			// In-flight HTLCs are not serialized; they are replayed on startup via
			// `DefaultResourceManager::replay_pending_htlcs`.
			pending_htlcs: new_hash_map(),
			congestion_bucket: congestion_bucket.0.unwrap(),
			last_congestion_misuse: last_congestion_misuse.0.unwrap(),
			protected_bucket: protected_bucket.0.unwrap(),
		})
	}
}
661+
568662
/// An implementation of [`ResourceManager`] for managing channel resources and informing HTLC
569663
/// forwarding decisions. It implements the core of the mitigation as proposed in
570664
/// https://github.com/lightning/bolts/pull/1280.
@@ -814,6 +908,85 @@ impl DefaultResourceManager {
814908
}
815909
}
816910

911+
/// A snapshot of an HTLC that was in flight when a [`DefaultResourceManager`] was
/// serialized. Pending HTLCs are not persisted with the manager; instead the
/// caller collects them and replays them through
/// [`DefaultResourceManager::replay_pending_htlcs`] on startup.
pub struct PendingHTLCReplay {
	/// Short channel id of the channel the HTLC arrived on.
	pub incoming_channel_id: u64,
	/// Amount of the incoming HTLC, in millisatoshis.
	pub incoming_amount_msat: u64,
	/// Identifier of the HTLC on the incoming channel.
	pub incoming_htlc_id: u64,
	/// CLTV expiry of the incoming HTLC.
	pub incoming_cltv_expiry: u32,
	/// Whether the incoming HTLC was flagged as accountable.
	pub incoming_accountable: bool,
	/// Short channel id of the channel the HTLC is to be forwarded over.
	pub outgoing_channel_id: u64,
	/// Amount to forward on the outgoing channel, in millisatoshis.
	pub outgoing_amount_msat: u64,
	/// Unix timestamp, in seconds, at which the HTLC was originally added.
	pub added_at_unix_seconds: u64,
	/// Block height at which the HTLC was originally added.
	pub height_added: u32,
}
922+
923+
impl Writeable for DefaultResourceManager {
	/// Serializes the manager's configuration and per-channel state.
	///
	/// Pending HTLCs inside each [`Channel`] are not persisted; callers are
	/// expected to replay them via
	/// [`DefaultResourceManager::replay_pending_htlcs`] after deserialization.
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// Hold the lock for the duration of the write so we serialize a
		// consistent snapshot of all channels.
		let channels = self.channels.lock().unwrap();
		write_tlv_fields!(writer, {
			(1, self.config, required),
			(3, channels, required),
		});
		Ok(())
	}
}
933+
934+
impl<ES: EntropySource> ReadableArgs<&ES> for DefaultResourceManager {
	/// Reads back a [`DefaultResourceManager`] previously written via [`Writeable`].
	///
	/// The entropy source is threaded through to each [`Channel`] so that
	/// general-bucket slot assignments can be re-derived from their persisted salts.
	fn read<R: Read>(
		reader: &mut R, entropy_source: &ES,
	) -> Result<DefaultResourceManager, DecodeError> {
		_init_and_read_len_prefixed_tlv_fields!(reader, {
			(1, config, required),
			(3, channels, (required: ReadableArgs, entropy_source)),
		});
		let channels: HashMap<u64, Channel> = channels.0.unwrap();
		Ok(DefaultResourceManager { config: config.0.unwrap(), channels: Mutex::new(channels) })
	}
}
946+
947+
impl<ES: EntropySource> ReadableArgs<&ES> for HashMap<u64, Channel> {
	/// Reads a length-prefixed map of short channel id -> [`Channel`].
	///
	/// NOTE(review): the map is not preallocated from `len` — presumably because
	/// the serialized length is untrusted input; confirm this matches the crate's
	/// other collection readers.
	fn read<R: Read>(r: &mut R, entropy_source: &ES) -> Result<Self, DecodeError> {
		let len: CollectionLength = Readable::read(r)?;
		let mut ret = new_hash_map();
		for _ in 0..len.0 {
			let k: u64 = Readable::read(r)?;
			let v = Channel::read(r, entropy_source)?;
			// A duplicate key indicates a corrupt (or malicious) serialization.
			if ret.insert(k, v).is_some() {
				return Err(DecodeError::InvalidValue);
			}
		}
		Ok(ret)
	}
}
961+
962+
impl DefaultResourceManager {
	/// Replays HTLCs that were in flight when the manager was last serialized,
	/// restoring the in-memory usage accounting that is intentionally not
	/// persisted (see the `static_value` fields on [`BucketResources`]).
	///
	/// This should only be called once during startup to replay pending HTLCs we
	/// had before shutdown.
	///
	/// Returns the forwarding outcome for each replayed HTLC in input order, or
	/// [`DecodeError::InvalidValue`] if any HTLC can no longer be added.
	pub fn replay_pending_htlcs<ES: EntropySource>(
		&self, pending_htlcs: &[PendingHTLCReplay], entropy_source: &ES,
	) -> Result<Vec<ForwardingOutcome>, DecodeError> {
		let mut forwarding_outcomes = Vec::with_capacity(pending_htlcs.len());
		for htlc in pending_htlcs {
			forwarding_outcomes.push(
				self.add_htlc(
					htlc.incoming_channel_id,
					htlc.incoming_amount_msat,
					htlc.incoming_cltv_expiry,
					htlc.outgoing_channel_id,
					htlc.outgoing_amount_msat,
					htlc.incoming_accountable,
					htlc.incoming_htlc_id,
					htlc.height_added,
					htlc.added_at_unix_seconds,
					entropy_source,
				)
				// Surface any add failure as a deserialization error since this
				// runs as part of startup state restoration.
				.map_err(|_| DecodeError::InvalidValue)?,
			);
		}
		Ok(forwarding_outcomes)
	}
}
989+
817990
/// A weighted average that decays over a specified window.
818991
///
819992
/// It enables tracking of historical behavior without storing individual data points.
@@ -861,6 +1034,16 @@ impl DecayingAverage {
8611034
}
8621035
}
8631036

1037+
impl_writeable_tlv_based!(DecayingAverage, {
	(1, value, required),
	(3, last_updated_unix_secs, required),
	(5, window, required),
	// `decay_rate` is a pure function of `window`, so it is recomputed on read
	// rather than persisted. NOTE(review): this expression must stay in sync with
	// the one used when constructing a DecayingAverage — confirm against the
	// constructor. A zero-length window yields a division by zero here.
	(_unused, decay_rate, (static_value, {
		let w: Duration = window.0.unwrap();
		0.5_f64.powf(2.0 / w.as_secs_f64())
	})),
});
8641047
/// Tracks an average value over multiple rolling windows to smooth out volatility.
8651048
///
8661049
/// It tracks the average value using a single window duration but extends observation over
@@ -925,6 +1108,13 @@ impl AggregatedWindowAverage {
9251108
}
9261109
}
9271110

1111+
// All fields of the rolling-window average are persisted directly; nothing needs
// to be recomputed on read.
impl_writeable_tlv_based!(AggregatedWindowAverage, {
	(1, start_timestamp_unix_secs, required),
	(3, window_count, required),
	(5, window_duration, required),
	(7, aggregated_revenue_decaying, required),
});
9281118
#[cfg(test)]
9291119
mod tests {
9301120
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -936,12 +1126,15 @@ mod tests {
9361126
channel::TOTAL_BITCOIN_SUPPLY_SATOSHIS,
9371127
resource_manager::{
9381128
AggregatedWindowAverage, BucketAssigned, BucketResources, Channel, DecayingAverage,
939-
DefaultResourceManager, ForwardingOutcome, GeneralBucket, HtlcRef, ResourceManager,
1129+
DefaultResourceManager, ForwardingOutcome, GeneralBucket, HtlcRef,
9401130
ResourceManagerConfig,
9411131
},
9421132
},
9431133
sign::EntropySource,
944-
util::test_utils::TestKeysInterface,
1134+
util::{
1135+
ser::{ReadableArgs, Writeable},
1136+
test_utils::TestKeysInterface,
1137+
},
9451138
};
9461139

9471140
const WINDOW: Duration = Duration::from_secs(2016 * 10 * 60);
@@ -1311,6 +1504,13 @@ mod tests {
13111504
outgoing_channel.outgoing_reputation.add_value(target_reputation, now).unwrap();
13121505
}
13131506

1507+
fn add_revenue(rm: &DefaultResourceManager, incoming_scid: u64, revenue: i64) {
1508+
let mut channels = rm.channels.lock().unwrap();
1509+
let channel = channels.get_mut(&incoming_scid).unwrap();
1510+
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
1511+
channel.incoming_revenue.add_value(revenue, now).unwrap();
1512+
}
1513+
13141514
fn fill_general_bucket(rm: &DefaultResourceManager, incoming_scid: u64) {
13151515
let mut channels = rm.channels.lock().unwrap();
13161516
let incoming_channel = channels.get_mut(&incoming_scid).unwrap();
@@ -2206,6 +2406,73 @@ mod tests {
22062406
assert!(get_htlc_bucket(&rm, INCOMING_SCID, htlc_id, OUTGOING_SCID_2).is_none());
22072407
}
22082408

2409+
	#[test]
	fn test_simple_manager_serialize_deserialize() {
		// This is not a complete test of the serialization/deserialization of the resource
		// manager because the pending HTLCs will be replayed through `replay_pending_htlcs` by
		// the upstream, i.e. the ChannelManager.
		let rm = create_test_resource_manager_with_channels();
		let entropy_source = TestKeysInterface::new(&[0; 32], Network::Testnet);

		add_test_htlc(&rm, false, 0, None, &entropy_source).unwrap();

		let reputation = 50_000_000;
		add_reputation(&rm, OUTGOING_SCID, reputation);

		let revenue = 70_000_000;
		add_revenue(&rm, INCOMING_SCID, revenue);

		let serialized_rm = rm.encode();

		// Capture the original incoming channel's slot assignment and salt so we can
		// verify they are re-derived identically after a round trip.
		let channels = rm.channels.lock().unwrap();
		let expected_incoming_channel = channels.get(&INCOMING_SCID).unwrap();
		let (expected_slots, expected_salt) = expected_incoming_channel
			.general_bucket
			.channels_slots
			.get(&OUTGOING_SCID)
			.unwrap()
			.clone();

		let deserialized_rm =
			DefaultResourceManager::read(&mut serialized_rm.as_slice(), &entropy_source).unwrap();
		let deserialized_channels = deserialized_rm.channels.lock().unwrap();
		assert_eq!(2, deserialized_channels.len());

		// The outgoing channel never had slots assigned in its own general bucket.
		let outgoing_channel = deserialized_channels.get(&OUTGOING_SCID).unwrap();
		assert!(outgoing_channel.general_bucket.channels_slots.is_empty());

		assert_eq!(outgoing_channel.outgoing_reputation.value, reputation);

		let incoming_channel = deserialized_channels.get(&INCOMING_SCID).unwrap();
		assert_eq!(incoming_channel.incoming_revenue.aggregated_revenue_decaying.value, revenue);

		assert_eq!(incoming_channel.general_bucket.channels_slots.len(), 1);

		// Slot assignment and salt must survive the round trip byte-for-byte.
		let (slots, salt) =
			incoming_channel.general_bucket.channels_slots.get(&OUTGOING_SCID).unwrap().clone();
		assert_eq!(slots, expected_slots);
		assert_eq!(salt, expected_salt);

		// Bucket allocations persist; usage counters are reset on read and are not
		// compared here (they are rebuilt via HTLC replay).
		let congestion_bucket = &incoming_channel.congestion_bucket;
		assert_eq!(
			congestion_bucket.slots_allocated,
			expected_incoming_channel.congestion_bucket.slots_allocated
		);
		assert_eq!(
			congestion_bucket.liquidity_allocated,
			expected_incoming_channel.congestion_bucket.liquidity_allocated
		);
		let protected_bucket = &incoming_channel.protected_bucket;
		assert_eq!(
			protected_bucket.slots_allocated,
			expected_incoming_channel.protected_bucket.slots_allocated
		);
		assert_eq!(
			protected_bucket.liquidity_allocated,
			expected_incoming_channel.protected_bucket.liquidity_allocated
		);
	}
2475+
22092476
#[test]
22102477
fn test_decaying_average_values() {
22112478
// Test average decay at different timestamps. The values we are asserting have been

0 commit comments

Comments
 (0)