// HCElection.aes
//
// Election contract: tracks the current leader, a sliding window of epoch
// descriptions, per-height rewards and penalties, the pinning proof and the
// finalization votes of the current epoch.
include "List.aes"
include "Pair.aes"
include "String.aes"

// Subset of the staking contract's interface that this contract calls.
contract interface MainStaking =
  entrypoint sorted_validators : () => list((address * int))
  entrypoint lock_stake : (int) => list((address * int))
  entrypoint add_rewards : (int, list(address * int)) => unit
  entrypoint add_penalties : (int, list(address * int)) => unit

main contract HCElection =
  // Description of one epoch: first height, number of blocks, the seed (set
  // by step_eoe) and the staking distribution returned by lock_stake.
  record epoch_info =
    { start                : int,
      length               : int,
      seed                 : option(bytes()),
      staking_distribution : option(list(address * int))
    }

  // Pin reward bookkeeping: configured base, reward of the current epoch and
  // the amount carried over from previous epochs.
  record pin_reward_info =
    { base       : int,
      current    : int,
      carry_over : int
    }

  // A signed fork vote; `sign_data` is what was signed and must contain `hash`.
  record vote =
    { producer  : address,
      hash      : hash,
      sign_data : bytes(),
      signature : signature
    }

  // A signed epoch-length vote; `sign_data` must contain the text
  // "epoch_length_delta" followed by `epoch_delta` (see validate_length_vote).
  record epoch_length_vote =
    { producer    : address,
      epoch_delta : int,
      sign_data   : bytes(),
      signature   : signature
    }

  // Finalization data recorded for one epoch by finalize_epoch and
  // finalize_epoch_length.
  record finalize_info =
    { epoch        : int,
      fork         : option(bytes()),
      epoch_length : option(int),
      pc_hash      : bytes(),
      producer     : address,
      votes        : option(list(vote)),
      length_votes : option(list(epoch_length_vote))
    }

  record state =
    { main_staking_ct : MainStaking,
      leader          : address,
      rewards         : map(int, list(address * int)),
      penalties       : map(int, list(address * int)),
      epoch           : int,
      epochs          : map(int, epoch_info),
      pin             : option(bytes()),
      pin_reward      : pin_reward_info,
      finalize        : option(finalize_info())
    }

  // Initial state: no epochs yet, the contract itself as placeholder leader.
  entrypoint init(main_staking_ct : MainStaking) =
    { main_staking_ct = main_staking_ct,
      leader          = Contract.address,
      rewards         = {},
      epoch           = 0,
      epochs          = {},
      pin             = None,
      finalize        = None,
      penalties       = {},
      pin_reward      = {base = 0, current = 0, carry_over = 0} }

  // Protocol-only, genesis-only: creates epoch 0 (a single block) and epochs
  // 1..4 of `epoch_length` blocks each, locking stake for epochs 1..4, and
  // sets the base pin reward. The current epoch becomes 1.
  stateful entrypoint init_epochs(epoch_length : int, base_pin_reward : int) =
    assert_protocol_call()
    require(Chain.block_height == 0, "Only in genesis")
    put(state{ epochs = { [0] = mk_epoch_info(0, 1, None, None),
                          [1] = mk_epoch_info(1, epoch_length, None, Some(state.main_staking_ct.lock_stake(1))),
                          [2] = mk_epoch_info(epoch_length + 1, epoch_length, None, Some(state.main_staking_ct.lock_stake(2))),
                          [3] = mk_epoch_info(2 * epoch_length + 1, epoch_length, None, Some(state.main_staking_ct.lock_stake(3))),
                          [4] = mk_epoch_info(3 * epoch_length + 1, epoch_length, None, Some(state.main_staking_ct.lock_stake(4)))
                        },
               epoch = 1,
               pin_reward.base = base_pin_reward
             })

  // Plain constructor for epoch_info.
  function mk_epoch_info(start : int,
                         length : int,
                         seed : option(bytes),
                         staking_distribution : option(list(address * int))) =
    {start = start, length = length, seed = seed, staking_distribution = staking_distribution}

  // Protocol-only: record the new leader.
  stateful entrypoint step(leader : address) =
    assert_protocol_call()
    put(state{ leader = leader })

  // Protocol-only: record the new leader (same state change as step).
  stateful entrypoint step_micro(leader : address) =
    assert_protocol_call()
    put(state{ leader = leader })

  // Protocol-only end-of-epoch step. Must run at the last height of the
  // current epoch. It
  //  - computes the next pin reward (base + optional carry over),
  //  - forwards penalties, then rewards, of the previous epoch to the staking
  //    contract,
  //  - applies a finalized epoch length (if recorded for this epoch) to
  //    epoch + 3, sets the seed of epoch + 2 and creates epoch + 4,
  //  - advances to the next epoch, sets the leader and clears the pin.
  stateful entrypoint step_eoe(leader : address, seed : bytes(),
                               next_base_pin_reward : int, carry_over_flag : bool) =
    assert_protocol_call()
    let epoch = state.epoch
    let ei = state.epochs[epoch]
    // Check the height first so that a mistimed call fails before any calls
    // to the staking contract are made.
    require(ei.start + ei.length - 1 == Chain.block_height, String.concats(["This is not the end: ", Int.to_str(ei.start + ei.length - 1), " : ", Int.to_str(Chain.block_height)]))

    // pin rewards
    let next_carry_over = calc_carry_over(carry_over_flag)
    let next_base = calc_next_base_reward(next_base_pin_reward)
    let next_reward = next_base + next_carry_over
    let pr = {current = next_reward, carry_over = next_carry_over, base = next_base}

    collect_penalties(state.epoch - 1)
    // pay rewards to sunset epoch
    pay_rewards(state.epoch - 1)

    // update epochs
    let ei_adjust =
      switch(state.finalize)
        None => state.epochs[epoch + 3]
        Some(info) =>
          if ( info.epoch == epoch )
            switch(info.epoch_length)
              None => state.epochs[epoch + 3]
              Some(epoch_length) =>
                state.epochs[epoch + 3]{ length = epoch_length }
          else
            state.epochs[epoch + 3]

    // Only epochs epoch .. epoch + 4 are kept; older entries are dropped.
    let new_epochs = { [epoch]     = state.epochs[epoch],
                       [epoch + 1] = state.epochs[epoch + 1],
                       [epoch + 2] = state.epochs[epoch + 2]{ seed = Some(seed) },
                       [epoch + 3] = ei_adjust,
                       [epoch + 4] = mk_epoch_info(ei_adjust.start + ei_adjust.length, ei_adjust.length, None,
                                                   Some(state.main_staking_ct.lock_stake(epoch + 4)))
                     }

    put(state{ leader     = leader,
               epoch      = epoch + 1,
               epochs     = new_epochs,
               pin        = None,
               pin_reward = pr})

  // Stores a pinning proof. Only the current leader may call this, and only
  // at the last height of the current epoch.
  stateful entrypoint pin(proof : bytes()) =
    let epoch = state.epoch
    let last = state.epochs[epoch].start + state.epochs[epoch].length - 1
    require(Chain.block_height == last, "Only in last block")
    require(Call.caller == state.leader, "Must be called by the last leader of epoch")
    put(state{pin = Some(proof)})

  // Protocol-only: records the attached value as a reward for `to` at `height`.
  payable stateful entrypoint add_reward(height : int, to : address) =
    assert_protocol_call()
    put(state{rewards[height = []] @ rs = (to, Call.value) :: rs})

  // Protocol-only: records a (positive) penalty for `to` at `height`.
  stateful entrypoint add_penalty(height : int, penalty : int, to : address) =
    assert_protocol_call()
    require(penalty > 0, "Penalty is expressed as a positive integer")
    put(state{penalties[height = []] @ rs = (to, penalty) :: rs})

  // Protocol-only: records a penalty for `to` and a negative entry (i.e. a
  // credit) of `reporter_percentage` percent of it for `reporter`.
  stateful entrypoint add_reported_penalty(height : int, penalty : int, to : address, reporter : address, reporter_percentage : int) =
    assert_protocol_call()
    require(penalty > 0, "Penalty is expressed as a positive integer")
    require(reporter_percentage >= 0 && reporter_percentage =< 100, "reporter reward percentage must be between 0 and 100")
    put(state{penalties[height = []] @ rs = (to, (penalty)) :: rs})
    let reporter_reward = (penalty * reporter_percentage) / 100
    put(state{penalties[height = []] @ rs = (reporter, -reporter_reward) :: rs})

  // Records the fork votes for the current epoch. Only the current leader may
  // call this, at the last height of the epoch, and every vote must validate.
  // If finalize data for this epoch already exists only `fork` and `votes`
  // are replaced; otherwise a fresh record is stored.
  stateful entrypoint finalize_epoch(epoch_number : int, fork : bytes(), pc_hash : bytes(), producer : address, votes : list(vote)) =
    let epoch = state.epoch
    let last = state.epochs[epoch].start + state.epochs[epoch].length - 1
    require(Chain.block_height == last, "Only in last block")
    require(Call.caller == state.leader, "Must be called by the last leader of epoch")
    require(epoch_number == state.epoch, "Not correct epoch")
    require(List.all(validate_vote, votes), "Invalid vote")
    let ei_final =
      switch(state.finalize)
        None => {epoch = epoch_number, fork = Some(fork), epoch_length = None, pc_hash = pc_hash, producer = producer, votes = Some(votes), length_votes = None}
        Some(finalize) =>
          if ( finalize.epoch == epoch_number )
            finalize{ fork = Some(fork), votes = Some(votes) }
          else
            {epoch = epoch_number, fork = Some(fork), epoch_length = None, pc_hash = pc_hash, producer = producer, votes = Some(votes), length_votes = None}

    put(state{finalize=Some(ei_final)})

  // Records the agreed epoch length for the current epoch together with its
  // votes. Same caller/height/epoch checks as finalize_epoch. If finalize
  // data for this epoch already exists only `epoch_length` and `length_votes`
  // are replaced; otherwise a fresh record is stored.
  stateful entrypoint finalize_epoch_length(epoch_number : int, epoch_length : int, pc_hash : bytes(), producer : address, votes : list(epoch_length_vote)) =
    let epoch = state.epoch
    let last = state.epochs[epoch].start + state.epochs[epoch].length - 1
    require(Chain.block_height == last, "Only in last block")
    require(Call.caller == state.leader, "Must be called by the last leader of epoch")
    require(epoch_number == state.epoch, "Not correct epoch")
    require(List.all(validate_length_vote, votes), "Invalid vote")
    // step_eoe copies this value into a future epoch; a non-positive length
    // would make that epoch's end-of-epoch check and the per-height reward
    // and penalty countdowns unusable.
    require(epoch_length > 0, "Epoch length must be positive")
    let ei_final =
      switch(state.finalize)
        None => {epoch = epoch_number, fork = None, epoch_length = Some(epoch_length), pc_hash = pc_hash, producer = producer, votes = None, length_votes = Some(votes)}
        Some(finalize) =>
          if ( finalize.epoch == epoch_number )
            finalize{ epoch_length = Some(epoch_length), length_votes = Some(votes) }
          else
            {epoch = epoch_number, fork = None, epoch_length = Some(epoch_length), pc_hash = pc_hash, producer = producer, votes = None, length_votes = Some(votes)}

    put(state{finalize=Some(ei_final)})

  // Current leader.
  entrypoint leader() =
    state.leader

  // Current epoch number.
  entrypoint epoch() =
    state.epoch

  // Length of the current epoch.
  entrypoint epoch_length() =
    state.epochs[state.epoch].length

  // Current epoch number together with its description.
  entrypoint epoch_info() =
    (state.epoch, state.epochs[state.epoch])

  // Description of `epoch`, restricted to the range [current - 1, current + 2].
  entrypoint epoch_info_epoch(epoch : int) =
    require(epoch >= state.epoch - 1 && epoch =< state.epoch + 2, "Epoch not in scope")
    state.epochs[epoch]

  // The staking contract given to init.
  entrypoint staking_contract() =
    state.main_staking_ct

  // Derives a schedule of `length` validators from `seed`: the seed is hashed
  // repeatedly and each hash, reduced modulo the total stake, selects a
  // validator by walking the cumulative stake in `validators`.
  entrypoint validator_schedule(seed : bytes(), validators : list(address * int), length : int) =
    let total_stake = List.foldl((+), 0, List.map(Pair.snd, validators))
    // Without any stake the modulo below is undefined and no validator can be
    // picked; an empty schedule (length == 0) needs no stake.
    require(length == 0 || total_stake > 0, "Total stake must be positive")
    // One extra hash operation to convert from bytes() to bytes(32)/hash
    validator_schedule_(Crypto.blake2b(seed), (s) => Bytes.to_int(s) mod total_stake, validators, length, [])

  // Pinning proof stored for the current epoch, if any.
  entrypoint pin_info() =
    state.pin

  // Current pin reward bookkeeping.
  entrypoint pin_reward_info() =
    state.pin_reward

  // Most recently stored finalize data, if any.
  entrypoint finalize_info() =
    state.finalize

  // Worker for validator_schedule: builds the schedule in reverse, re-hashing
  // the seed for every slot, until the counter reaches 0.
  function
    validator_schedule_(_, _, _, 0, schedule) = List.reverse(schedule)
    validator_schedule_(seed0, rnd, validators, n, schedule) =
      let seed = Crypto.blake2b(seed0)
      let validator = pick_validator(rnd(seed), validators)
      validator_schedule_(seed, rnd, validators, n - 1, validator :: schedule)

  // Returns the first validator whose cumulative stake exceeds `n`.
  function
    pick_validator(n, (validator, stake) :: _) | n < stake = validator
    pick_validator(n, (_, stake) :: validators) = pick_validator(n - stake, validators)

  // Aborts unless the caller is the creator of this contract.
  function assert_protocol_call() =
    require(Call.caller == Contract.creator, "Must be called by the protocol")

  // Carry over for the next epoch: previous base + previous carry over when
  // the flag is set, otherwise 0.
  function calc_carry_over(carry_over_flag : bool) =
    let ei = state.pin_reward
    if (carry_over_flag)
      ei.base + ei.carry_over
    else
      0

  // A negative argument keeps the current base pin reward.
  function calc_next_base_reward(next_base_pin_reward : int) =
    if ( next_base_pin_reward >= 0 )
      next_base_pin_reward
    else
      state.pin_reward.base

  // Sums the rewards recorded for every height of epoch `e` per address and
  // sends them, together with the total as value, to the staking contract.
  stateful function pay_rewards(e : int) =
    let ei = state.epochs[e]
    let (rewards, tot) = pay_rewards_(ei.start, ei.length, {}, 0)
    state.main_staking_ct.add_rewards(value = tot, e, rewards)

  // Walks `n` heights starting at `h`, accumulating rewards per address and
  // removing each processed height from state.rewards.
  stateful function
    pay_rewards_(_, 0, acc, tot) = (Map.to_list(acc), tot)
    pay_rewards_(h, n, acc, tot) =
      switch(Map.lookup(h, state.rewards))
        None => pay_rewards_(h + 1, n - 1, acc, tot)
        Some(aas) =>
          let (acc1, tot1) = List.foldl(pay_reward_, (acc, tot), aas)
          put(state{rewards @ r = Map.delete(h, r)})
          pay_rewards_(h + 1, n - 1, acc1, tot1)

  // Fold step: add `amt` to `addr`'s entry and to the running total.
  function pay_reward_((acc, tot), (addr, amt)) =
    (acc{[addr = 0] @ r = r + amt}, tot + amt)

  // Sums the penalties recorded for every height of `epoch` per address and
  // passes them, sorted, to the staking contract.
  stateful function collect_penalties(epoch : int) =
    let ei = state.epochs[epoch]
    let pens = collect_penalties_(ei.start, ei.length, {})
    state.main_staking_ct.add_penalties(epoch, List.sort(compare_pens_, pens)) // penalties first to fill pool, rewards second

  // Walks `n` heights starting at `h`, accumulating penalties per address and
  // removing each processed height from state.penalties.
  stateful function
    collect_penalties_(_, 0, acc) = Map.to_list(acc)
    collect_penalties_(h, n, acc) =
      switch(Map.lookup(h, state.penalties))
        None => collect_penalties_(h + 1, n - 1, acc)
        Some(aas) =>
          let acc1 = List.foldl(collect_penalty_, acc, aas)
          put(state{penalties @ r = Map.delete(h, r)})
          collect_penalties_(h + 1, n - 1, acc1)

  // Fold step: add `amt` to `addr`'s accumulated penalty.
  function collect_penalty_(acc, (addr, amt)) =
    acc{[addr = 0] @ r = r + amt}

  // sort largest penalties first
  function compare_pens_((_, x), (_, y)) =
    x > y

  // A vote is valid when its signature over `sign_data` verifies against the
  // producer and `sign_data` contains the voted hash.
  function validate_vote(vote) =
    if(Crypto.verify_sig(vote.sign_data, vote.producer, vote.signature))
      contains_bytes(Bytes.to_any_size(vote.hash), vote.sign_data)
    else
      false

  // A length vote is valid when its signature over `sign_data` verifies
  // against the producer and `sign_data` contains "epoch_length_delta"
  // immediately followed by the decimal `epoch_delta`.
  function validate_length_vote(vote) =
    if(Crypto.verify_sig(vote.sign_data, vote.producer, vote.signature))
      let needle = String.to_bytes(String.concats(["epoch_length_delta", Int.to_str(vote.epoch_delta)]))
      contains_bytes(needle, vote.sign_data)
    else
      false

  // True when `needle` occurs as a contiguous byte sequence in `haystack`.
  function contains_bytes(needle : bytes(), haystack : bytes()) : bool =
    contains_bytes_(Bytes.size(needle), needle, haystack)

  // Naive scan: compare the first `len` bytes with the needle, then drop one
  // byte from the front and retry until the haystack is shorter than `len`.
  function contains_bytes_(len: int, needle : bytes(), haystack : bytes()) : bool =
    if(Bytes.size(haystack) < len)
      false
    else
      let Some((bs_, _)) = Bytes.split_any(haystack, len)
      if(bs_ == needle)
        true
      else
        let Some((_, haystack_)) = Bytes.split_any(haystack, 1)
        contains_bytes_(len, needle, haystack_)