99
1010#![ allow( dead_code) ]
1111
12+ use bitcoin:: io:: Read ;
1213use core:: { fmt:: Display , time:: Duration } ;
1314
1415use crate :: {
1516 crypto:: chacha20:: ChaCha20 ,
16- ln:: channel:: TOTAL_BITCOIN_SUPPLY_SATOSHIS ,
17+ io,
18+ ln:: { channel:: TOTAL_BITCOIN_SUPPLY_SATOSHIS , msgs:: DecodeError } ,
1719 prelude:: { hash_map:: Entry , new_hash_map, HashMap } ,
1820 sign:: EntropySource ,
1921 sync:: Mutex ,
22+ util:: ser:: { CollectionLength , Readable , ReadableArgs , Writeable , Writer } ,
2023} ;
2124
2225/// A trait for managing channel resources and making HTLC forwarding decisions.
@@ -129,6 +132,14 @@ impl Default for ResourceManagerConfig {
129132 }
130133}
131134
135+ impl_writeable_tlv_based ! ( ResourceManagerConfig , {
136+ ( 1 , general_allocation_pct, required) ,
137+ ( 3 , congestion_allocation_pct, required) ,
138+ ( 5 , resolution_period, required) ,
139+ ( 7 , revenue_window, required) ,
140+ ( 9 , reputation_multiplier, required) ,
141+ } ) ;
142+
132143/// The outcome of an HTLC forwarding decision.
133144#[ derive( PartialEq , Eq , Debug ) ]
134145pub enum ForwardingOutcome {
@@ -337,6 +348,47 @@ impl GeneralBucket {
337348 }
338349}
339350
351+ impl Writeable for GeneralBucket {
352+ fn write < W : Writer > ( & self , writer : & mut W ) -> Result < ( ) , io:: Error > {
353+ let channel_info: HashMap < u64 , [ u8 ; 32 ] > =
354+ self . channels_slots . iter ( ) . map ( |( scid, ( _slots, salt) ) | ( * scid, * salt) ) . collect ( ) ;
355+
356+ write_tlv_fields ! ( writer, {
357+ ( 1 , self . scid, required) ,
358+ ( 3 , self . total_slots, required) ,
359+ ( 5 , self . total_liquidity, required) ,
360+ ( 7 , channel_info, required) ,
361+ } ) ;
362+ Ok ( ( ) )
363+ }
364+ }
365+
366+ impl < ES : EntropySource > ReadableArgs < & ES > for GeneralBucket {
367+ fn read < R : Read > ( reader : & mut R , entropy_source : & ES ) -> Result < Self , DecodeError > {
368+ _init_and_read_len_prefixed_tlv_fields ! ( reader, {
369+ ( 1 , our_scid, required) ,
370+ ( 3 , general_total_slots, required) ,
371+ ( 5 , general_total_liquidity, required) ,
372+ ( 7 , channel_info, required) ,
373+ } ) ;
374+
375+ let mut general_bucket = GeneralBucket :: new (
376+ our_scid. 0 . unwrap ( ) ,
377+ general_total_slots. 0 . unwrap ( ) ,
378+ general_total_liquidity. 0 . unwrap ( ) ,
379+ ) ;
380+
381+ let channel_info: HashMap < u64 , [ u8 ; 32 ] > = channel_info. 0 . unwrap ( ) ;
382+ for ( outgoing_scid, salt) in channel_info {
383+ general_bucket
384+ . assign_slots_for_channel ( outgoing_scid, Some ( salt) , entropy_source)
385+ . map_err ( |_| DecodeError :: InvalidValue ) ?;
386+ }
387+
388+ Ok ( general_bucket)
389+ }
390+ }
391+
340392struct BucketResources {
341393 slots_allocated : u16 ,
342394 slots_used : u16 ,
@@ -374,7 +426,13 @@ impl BucketResources {
374426 }
375427}
376428
377- #[ derive( Debug , Clone ) ]
429+ impl_writeable_tlv_based ! ( BucketResources , {
430+ ( 1 , slots_allocated, required) ,
431+ ( _unused, slots_used, ( static_value, 0 ) ) ,
432+ ( 3 , liquidity_allocated, required) ,
433+ ( _unused, liquidity_used, ( static_value, 0 ) ) ,
434+ } ) ;
435+
378436struct PendingHTLC {
379437 incoming_amount_msat : u64 ,
380438 fee : u64 ,
@@ -565,6 +623,42 @@ impl Channel {
565623 }
566624}
567625
626+ impl Writeable for Channel {
627+ fn write < W : Writer > ( & self , writer : & mut W ) -> Result < ( ) , io:: Error > {
628+ write_tlv_fields ! ( writer, {
629+ ( 1 , self . outgoing_reputation, required) ,
630+ ( 3 , self . incoming_revenue, required) ,
631+ ( 5 , self . general_bucket, required) ,
632+ ( 7 , self . congestion_bucket, required) ,
633+ ( 9 , self . last_congestion_misuse, required) ,
634+ ( 11 , self . protected_bucket, required)
635+ } ) ;
636+ Ok ( ( ) )
637+ }
638+ }
639+
640+ impl < ES : EntropySource > ReadableArgs < & ES > for Channel {
641+ fn read < R : Read > ( reader : & mut R , entropy_source : & ES ) -> Result < Self , DecodeError > {
642+ _init_and_read_len_prefixed_tlv_fields ! ( reader, {
643+ ( 1 , outgoing_reputation, required) ,
644+ ( 3 , incoming_revenue, required) ,
645+ ( 5 , general_bucket, ( required: ReadableArgs , entropy_source) ) ,
646+ ( 7 , congestion_bucket, required) ,
647+ ( 9 , last_congestion_misuse, required) ,
648+ ( 11 , protected_bucket, required)
649+ } ) ;
650+ Ok ( Channel {
651+ outgoing_reputation : outgoing_reputation. 0 . unwrap ( ) ,
652+ incoming_revenue : incoming_revenue. 0 . unwrap ( ) ,
653+ general_bucket : general_bucket. 0 . unwrap ( ) ,
654+ pending_htlcs : new_hash_map ( ) ,
655+ congestion_bucket : congestion_bucket. 0 . unwrap ( ) ,
656+ last_congestion_misuse : last_congestion_misuse. 0 . unwrap ( ) ,
657+ protected_bucket : protected_bucket. 0 . unwrap ( ) ,
658+ } )
659+ }
660+ }
661+
568662/// An implementation of [`ResourceManager`] for managing channel resources and informing HTLC
569663/// forwarding decisions. It implements the core of the mitigation as proposed in
570664/// https://github.com/lightning/bolts/pull/1280.
@@ -814,6 +908,85 @@ impl DefaultResourceManager {
814908 }
815909}
816910
911+ pub struct PendingHTLCReplay {
912+ pub incoming_channel_id : u64 ,
913+ pub incoming_amount_msat : u64 ,
914+ pub incoming_htlc_id : u64 ,
915+ pub incoming_cltv_expiry : u32 ,
916+ pub incoming_accountable : bool ,
917+ pub outgoing_channel_id : u64 ,
918+ pub outgoing_amount_msat : u64 ,
919+ pub added_at_unix_seconds : u64 ,
920+ pub height_added : u32 ,
921+ }
922+
923+ impl Writeable for DefaultResourceManager {
924+ fn write < W : Writer > ( & self , writer : & mut W ) -> Result < ( ) , io:: Error > {
925+ let channels = self . channels . lock ( ) . unwrap ( ) ;
926+ write_tlv_fields ! ( writer, {
927+ ( 1 , self . config, required) ,
928+ ( 3 , channels, required) ,
929+ } ) ;
930+ Ok ( ( ) )
931+ }
932+ }
933+
934+ impl < ES : EntropySource > ReadableArgs < & ES > for DefaultResourceManager {
935+ fn read < R : Read > (
936+ reader : & mut R , entropy_source : & ES ,
937+ ) -> Result < DefaultResourceManager , DecodeError > {
938+ _init_and_read_len_prefixed_tlv_fields ! ( reader, {
939+ ( 1 , config, required) ,
940+ ( 3 , channels, ( required: ReadableArgs , entropy_source) ) ,
941+ } ) ;
942+ let channels: HashMap < u64 , Channel > = channels. 0 . unwrap ( ) ;
943+ Ok ( DefaultResourceManager { config : config. 0 . unwrap ( ) , channels : Mutex :: new ( channels) } )
944+ }
945+ }
946+
947+ impl < ES : EntropySource > ReadableArgs < & ES > for HashMap < u64 , Channel > {
948+ fn read < R : Read > ( r : & mut R , entropy_source : & ES ) -> Result < Self , DecodeError > {
949+ let len: CollectionLength = Readable :: read ( r) ?;
950+ let mut ret = new_hash_map ( ) ;
951+ for _ in 0 ..len. 0 {
952+ let k: u64 = Readable :: read ( r) ?;
953+ let v = Channel :: read ( r, entropy_source) ?;
954+ if ret. insert ( k, v) . is_some ( ) {
955+ return Err ( DecodeError :: InvalidValue ) ;
956+ }
957+ }
958+ Ok ( ret)
959+ }
960+ }
961+
962+ impl DefaultResourceManager {
963+ // This should only be called once during startup to replay pending HTLCs we had before
964+ // shutdown.
965+ pub fn replay_pending_htlcs < ES : EntropySource > (
966+ & self , pending_htlcs : & [ PendingHTLCReplay ] , entropy_source : & ES ,
967+ ) -> Result < Vec < ForwardingOutcome > , DecodeError > {
968+ let mut forwarding_outcomes = Vec :: with_capacity ( pending_htlcs. len ( ) ) ;
969+ for htlc in pending_htlcs {
970+ forwarding_outcomes. push (
971+ self . add_htlc (
972+ htlc. incoming_channel_id ,
973+ htlc. incoming_amount_msat ,
974+ htlc. incoming_cltv_expiry ,
975+ htlc. outgoing_channel_id ,
976+ htlc. outgoing_amount_msat ,
977+ htlc. incoming_accountable ,
978+ htlc. incoming_htlc_id ,
979+ htlc. height_added ,
980+ htlc. added_at_unix_seconds ,
981+ entropy_source,
982+ )
983+ . map_err ( |_| DecodeError :: InvalidValue ) ?,
984+ ) ;
985+ }
986+ Ok ( forwarding_outcomes)
987+ }
988+ }
989+
817990/// A weighted average that decays over a specified window.
818991///
819992/// It enables tracking of historical behavior without storing individual data points.
@@ -861,6 +1034,16 @@ impl DecayingAverage {
8611034 }
8621035}
8631036
1037+ impl_writeable_tlv_based ! ( DecayingAverage , {
1038+ ( 1 , value, required) ,
1039+ ( 3 , last_updated_unix_secs, required) ,
1040+ ( 5 , window, required) ,
1041+ ( _unused, decay_rate, ( static_value, {
1042+ let w: Duration = window. 0 . unwrap( ) ;
1043+ 0.5_f64 . powf( 2.0 / w. as_secs_f64( ) )
1044+ } ) ) ,
1045+ } ) ;
1046+
8641047/// Tracks an average value over multiple rolling windows to smooth out volatility.
8651048///
8661049/// It tracks the average value using a single window duration but extends observation over
@@ -925,6 +1108,13 @@ impl AggregatedWindowAverage {
9251108 }
9261109}
9271110
1111+ impl_writeable_tlv_based ! ( AggregatedWindowAverage , {
1112+ ( 1 , start_timestamp_unix_secs, required) ,
1113+ ( 3 , window_count, required) ,
1114+ ( 5 , window_duration, required) ,
1115+ ( 7 , aggregated_revenue_decaying, required) ,
1116+ } ) ;
1117+
9281118#[ cfg( test) ]
9291119mod tests {
9301120 use std:: time:: { Duration , SystemTime , UNIX_EPOCH } ;
@@ -936,12 +1126,15 @@ mod tests {
9361126 channel:: TOTAL_BITCOIN_SUPPLY_SATOSHIS ,
9371127 resource_manager:: {
9381128 AggregatedWindowAverage , BucketAssigned , BucketResources , Channel , DecayingAverage ,
939- DefaultResourceManager , ForwardingOutcome , GeneralBucket , HtlcRef , ResourceManager ,
1129+ DefaultResourceManager , ForwardingOutcome , GeneralBucket , HtlcRef ,
9401130 ResourceManagerConfig ,
9411131 } ,
9421132 } ,
9431133 sign:: EntropySource ,
944- util:: test_utils:: TestKeysInterface ,
1134+ util:: {
1135+ ser:: { ReadableArgs , Writeable } ,
1136+ test_utils:: TestKeysInterface ,
1137+ } ,
9451138 } ;
9461139
9471140 const WINDOW : Duration = Duration :: from_secs ( 2016 * 10 * 60 ) ;
@@ -1311,6 +1504,13 @@ mod tests {
13111504 outgoing_channel. outgoing_reputation . add_value ( target_reputation, now) . unwrap ( ) ;
13121505 }
13131506
1507+ fn add_revenue ( rm : & DefaultResourceManager , incoming_scid : u64 , revenue : i64 ) {
1508+ let mut channels = rm. channels . lock ( ) . unwrap ( ) ;
1509+ let channel = channels. get_mut ( & incoming_scid) . unwrap ( ) ;
1510+ let now = SystemTime :: now ( ) . duration_since ( UNIX_EPOCH ) . unwrap ( ) . as_secs ( ) ;
1511+ channel. incoming_revenue . add_value ( revenue, now) . unwrap ( ) ;
1512+ }
1513+
13141514 fn fill_general_bucket ( rm : & DefaultResourceManager , incoming_scid : u64 ) {
13151515 let mut channels = rm. channels . lock ( ) . unwrap ( ) ;
13161516 let incoming_channel = channels. get_mut ( & incoming_scid) . unwrap ( ) ;
@@ -2206,6 +2406,73 @@ mod tests {
22062406 assert ! ( get_htlc_bucket( & rm, INCOMING_SCID , htlc_id, OUTGOING_SCID_2 ) . is_none( ) ) ;
22072407 }
22082408
2409+ #[ test]
2410+ fn test_simple_manager_serialize_deserialize ( ) {
2411+ // This is not a complete test of the serialization/deserialization of the resource
2412+ // manager because the pending HTLCs will be replayed through `replay_pending_htlcs` by
2413+ // the upstream i.e ChannelManager.
2414+ let rm = create_test_resource_manager_with_channels ( ) ;
2415+ let entropy_source = TestKeysInterface :: new ( & [ 0 ; 32 ] , Network :: Testnet ) ;
2416+
2417+ add_test_htlc ( & rm, false , 0 , None , & entropy_source) . unwrap ( ) ;
2418+
2419+ let reputation = 50_000_000 ;
2420+ add_reputation ( & rm, OUTGOING_SCID , reputation) ;
2421+
2422+ let revenue = 70_000_000 ;
2423+ add_revenue ( & rm, INCOMING_SCID , revenue) ;
2424+
2425+ let serialized_rm = rm. encode ( ) ;
2426+
2427+ let channels = rm. channels . lock ( ) . unwrap ( ) ;
2428+ let expected_incoming_channel = channels. get ( & INCOMING_SCID ) . unwrap ( ) ;
2429+ let ( expected_slots, expected_salt) = expected_incoming_channel
2430+ . general_bucket
2431+ . channels_slots
2432+ . get ( & OUTGOING_SCID )
2433+ . unwrap ( )
2434+ . clone ( ) ;
2435+
2436+ let deserialized_rm =
2437+ DefaultResourceManager :: read ( & mut serialized_rm. as_slice ( ) , & entropy_source) . unwrap ( ) ;
2438+ let deserialized_channels = deserialized_rm. channels . lock ( ) . unwrap ( ) ;
2439+ assert_eq ! ( 2 , deserialized_channels. len( ) ) ;
2440+
2441+ let outgoing_channel = deserialized_channels. get ( & OUTGOING_SCID ) . unwrap ( ) ;
2442+ assert ! ( outgoing_channel. general_bucket. channels_slots. is_empty( ) ) ;
2443+
2444+ assert_eq ! ( outgoing_channel. outgoing_reputation. value, reputation) ;
2445+
2446+ let incoming_channel = deserialized_channels. get ( & INCOMING_SCID ) . unwrap ( ) ;
2447+ assert_eq ! ( incoming_channel. incoming_revenue. aggregated_revenue_decaying. value, revenue) ;
2448+
2449+ assert_eq ! ( incoming_channel. general_bucket. channels_slots. len( ) , 1 ) ;
2450+
2451+ let ( slots, salt) =
2452+ incoming_channel. general_bucket . channels_slots . get ( & OUTGOING_SCID ) . unwrap ( ) . clone ( ) ;
2453+ assert_eq ! ( slots, expected_slots) ;
2454+ assert_eq ! ( salt, expected_salt) ;
2455+
2456+ let congestion_bucket = & incoming_channel. congestion_bucket ;
2457+ assert_eq ! (
2458+ congestion_bucket. slots_allocated,
2459+ expected_incoming_channel. congestion_bucket. slots_allocated
2460+ ) ;
2461+ assert_eq ! (
2462+ congestion_bucket. liquidity_allocated,
2463+ expected_incoming_channel. congestion_bucket. liquidity_allocated
2464+ ) ;
2465+ let protected_bucket = & incoming_channel. protected_bucket ;
2466+ assert_eq ! (
2467+ protected_bucket. slots_allocated,
2468+ expected_incoming_channel. protected_bucket. slots_allocated
2469+ ) ;
2470+ assert_eq ! (
2471+ protected_bucket. liquidity_allocated,
2472+ expected_incoming_channel. protected_bucket. liquidity_allocated
2473+ ) ;
2474+ }
2475+
22092476 #[ test]
22102477 fn test_decaying_average_values ( ) {
22112478 // Test average decay at different timestamps. The values we are asserting have been
0 commit comments