/ test / contracts / HCElection.aes
HCElection.aes
  1  include "List.aes"
  2  include "Pair.aes"
  3  include "String.aes"
  4  
  5  contract interface MainStaking =
       // The entrypoints of the staking contract that HCElection relies on.
       // sorted_validators is declared but not called anywhere in this file.
  6    entrypoint sorted_validators : () => list((address * int))
       // Called with an epoch number; the returned list is stored as that
       // epoch's staking_distribution.
  7    entrypoint lock_stake  : (int) => list((address * int))
       // Called from pay_rewards with an epoch number and per-address totals;
       // the call also passes the summed amount as `value`.
  8    entrypoint add_rewards : (int, list(address * int)) => unit
       // Called from collect_penalties with an epoch number and per-address
       // totals, largest first.
  9    entrypoint add_penalties : (int, list(address * int)) => unit
 10  
 11  main contract HCElection =
 12    record epoch_info =
         // Per-epoch data kept in state.epochs, keyed by epoch number.
 13      { start                 : int,   // first block height of the epoch
 14        length                : int,   // number of heights; the last one is start + length - 1
 15        seed                  : option(bytes()),   // None until step_eoe stores a seed
 16        staking_distribution  : option(list(address * int))   // result of MainStaking.lock_stake, None for epoch 0
 17      }
 18  
 19    record pin_reward_info =
         // Pin reward figures; recomputed by step_eoe at every end of epoch.
 20      { base       : int,   // set by init_epochs, optionally replaced in step_eoe
 21        current    : int,   // base + carry_over as computed by step_eoe
 22        carry_over : int    // previous base + carry_over if the carry-over flag was set, else 0
 23      }
 24  
 25    record vote =
         // One vote handed to finalize_epoch; checked by validate_vote.
 26      { producer    : address,    // key the signature is verified against
 27        hash        : hash,       // must occur as a byte sequence inside sign_data
 28        sign_data   : bytes(),    // the message that was signed
 29        signature   : signature   // signature over sign_data
 30      }
 31  
 32    record epoch_length_vote =
         // One vote handed to finalize_epoch_length; checked by validate_length_vote.
 33      { producer    : address,    // key the signature is verified against
 34        epoch_delta : int,        // sign_data must contain "epoch_length_delta" followed by this number
 35        sign_data   : bytes(),    // the message that was signed
 36        signature   : signature   // signature over sign_data
 37      }
 38  
 39    record finalize_info =
         // Result of finalize_epoch and/or finalize_epoch_length for one epoch.
         // pc_hash and producer come from whichever of the two calls created the
         // record; a later call for the same epoch only fills in its own fields.
 40      { epoch        : int,                 // epoch number the record belongs to
 41        fork         : option(bytes()),     // set by finalize_epoch
 42        epoch_length : option(int),         // set by finalize_epoch_length, read by step_eoe
 43        pc_hash      : bytes(),
 44        producer     : address,
 45        votes        : option(list(vote)),  // set by finalize_epoch
 46        length_votes : option(list(epoch_length_vote))   // set by finalize_epoch_length
 47      }
 48  
 49    record state =
         // Contract state.
 50      { main_staking_ct       : MainStaking,   // staking contract given to init
 51        leader                : address,       // set by step, step_micro and step_eoe
 52        rewards               : map(int, list(address * int)),   // height => (receiver, amount), emptied by pay_rewards
 53        penalties             : map(int, list(address * int)),   // height => (address, amount), emptied by collect_penalties
 54        epoch                 : int,           // current epoch number
 55        epochs                : map(int, epoch_info),   // epoch number => epoch data
 56        pin                   : option(bytes()),   // proof stored by pin, reset to None by step_eoe
 57        pin_reward            : pin_reward_info,
 58        finalize              : option(finalize_info())   // latest record written by finalize_epoch / finalize_epoch_length
 59      }
 60  
 61    entrypoint init(main_staking_ct : MainStaking) =
         // Initial state: epoch 0, empty maps, no pin, no finalize record, and the
         // contract's own address as leader. Epoch data is added by init_epochs.
 62      let zero_pin_reward = {base = 0, current = 0, carry_over = 0}
 63      { main_staking_ct = main_staking_ct,
 64        leader          = Contract.address,
 65        epoch           = 0,
 66        epochs          = {},
 67        rewards         = {},
 68        penalties       = {},
 69        pin             = None, finalize = None,
 70        pin_reward      = zero_pin_reward }
 71  
 72    stateful entrypoint init_epochs(epoch_length : int, base_pin_reward : int) =
         // Protocol-only and only at height 0: creates epochs 0..4, makes epoch 1
         // the current epoch and stores the base pin reward. Epoch 0 is the single
         // height 0; epochs 1..4 follow back to back, each epoch_length long and
         // each with the stake locked for it in the staking contract.
 73      assert_protocol_call()
 74      require(Chain.block_height == 0, "Only in genesis")
 75      let staking = state.main_staking_ct
 76      let genesis_epochs =
 77        { [0] = mk_epoch_info(0, 1, None, None),
 78          [1] = mk_epoch_info(1, epoch_length, None, Some(staking.lock_stake(1))),
 79          [2] = mk_epoch_info(epoch_length + 1, epoch_length, None, Some(staking.lock_stake(2))),
 80          [3] = mk_epoch_info(2 * epoch_length + 1, epoch_length, None, Some(staking.lock_stake(3))),
 81          [4] = mk_epoch_info(3 * epoch_length + 1, epoch_length, None, Some(staking.lock_stake(4))) }
 82      put(state{ epoch = 1, epochs = genesis_epochs, pin_reward.base = base_pin_reward })
 84  
 85    function mk_epoch_info(start : int,
 86                           length : int,
 87                           seed : option(bytes()),
 88                           staking_distribution : option(list(address * int))) =
         // Builds an epoch_info record from its four fields. The seed is typed
         // option(bytes()) to match the epoch_info.seed field it is stored in.
 89      {start = start, length = length, seed = seed, staking_distribution = staking_distribution}
 90  
 91    stateful entrypoint step(leader : address) =
         // Protocol-only: stores `leader` as the current leader; nothing else changes.
 92      assert_protocol_call()
 93      put(state{ leader = leader })
 94  
 95    stateful entrypoint step_micro(leader : address) =
         // Protocol-only: stores `leader` as the current leader. The body is the
         // same as that of step.
 96      assert_protocol_call()
 97      put(state{ leader = leader })
 98  
 99    stateful entrypoint step_eoe(leader : address, seed : bytes(),
100                                 next_base_pin_reward : int, carry_over_flag : bool) =
         // End-of-epoch step; protocol only, and only valid at the last height of
         // the current epoch. It recomputes the pin reward, hands the previous
         // epoch's penalties and rewards to the staking contract, moves the epoch
         // window forward by one, stores the new leader and clears the pin.
         //   seed                 - stored as the seed of epoch + 2
         //   next_base_pin_reward - new base pin reward; a negative value keeps the old one
         //   carry_over_flag      - whether the current base and carry-over are carried on
101      assert_protocol_call()
102      let epoch = state.epoch
103      let ei = state.epochs[epoch]
104      let last = ei.start + ei.length - 1
         // Checked up front so that a mistimed call aborts before anything is
         // sent to the staking contract.
105      require(last == Chain.block_height, String.concats(["This is not the end: ", Int.to_str(last), " : ", Int.to_str(Chain.block_height)]))
         // pin rewards
106      let next_carry_over = calc_carry_over(carry_over_flag)
107      let next_base = calc_next_base_reward(next_base_pin_reward)
108      let pr = {current = next_base + next_carry_over, carry_over = next_carry_over, base = next_base}
         // penalties first, then rewards, both for the sunset epoch
109      collect_penalties(epoch - 1)
110      pay_rewards(epoch - 1)
         // A length recorded by finalize_epoch_length is applied to epoch + 3, but
         // only if the finalize record belongs to the epoch that ends now.
111      let finalized_length =
112        switch(state.finalize)
113          None       => None
114          Some(info) =>
115            if ( info.epoch == epoch )
116              info.epoch_length
117            else
118              None
119      let ei_plus3 = state.epochs[epoch + 3]
120      let ei_adjust =
121        switch(finalized_length)
122          None               => ei_plus3
123          Some(epoch_length) => ei_plus3{ length = epoch_length }
         // Keep epoch .. epoch + 3 and append epoch + 4, which starts right after
         // epoch + 3 and inherits its length; the entry for epoch - 1 is dropped.
124      let new_epochs = { [epoch]     = ei,
125                         [epoch + 1] = state.epochs[epoch + 1],
126                         [epoch + 2] = state.epochs[epoch + 2]{ seed = Some(seed) },
127                         [epoch + 3] = ei_adjust,
128                         [epoch + 4] = mk_epoch_info(ei_adjust.start + ei_adjust.length, ei_adjust.length, None,
129                                                     Some(state.main_staking_ct.lock_stake(epoch + 4)))
130                       }
131      put(state{ leader = leader,
132                 epoch  = epoch + 1,
133                 epochs = new_epochs,
134                 pin = None,
135                 pin_reward = pr})
141  
142    stateful entrypoint pin(proof : bytes()) =
         // Stores `proof` as the pin. Only the current leader may call this, and
         // only at the last height of the current epoch.
143      let ei = state.epochs[state.epoch]
144      let last_height = ei.start + ei.length - 1
145      require(Chain.block_height == last_height, "Only in last block")
146      require(Call.caller == state.leader, "Must be called by the last leader of epoch")
147      put(state{pin = Some(proof)})
148  
149    payable stateful entrypoint add_reward(height : int, to : address) =
         // Protocol-only: records the amount sent with the call as a reward for
         // `to` at `height`, in front of any rewards already stored for that height.
150      assert_protocol_call()
151      put(state{rewards[height] = (to, Call.value) :: state.rewards[height = []]})
152  
153    stateful entrypoint add_penalty(height : int, penalty : int, to : address) =
         // Protocol-only: records a strictly positive `penalty` for `to` at
         // `height`, in front of any penalties already stored for that height.
154      assert_protocol_call()
155      require(penalty > 0, "Penalty is expressed as a positive integer")
156      put(state{penalties[height] = (to, penalty) :: state.penalties[height = []]})
157  
158    stateful entrypoint add_reported_penalty(height : int, penalty : int, to : address, reporter : address, reporter_percentage : int) =
         // Protocol-only: records `penalty` for `to` at `height` and, in front of
         // it, a negative entry for `reporter` worth reporter_percentage percent of
         // the penalty (integer division).
159      assert_protocol_call()
160      require(penalty > 0, "Penalty is expressed as a positive integer")
161      require(reporter_percentage >= 0 && reporter_percentage =< 100, "reporter reward percentage must be between 0 and 100")
162      let reporter_reward = (penalty * reporter_percentage) / 100
163      let earlier = state.penalties[height = []]
164      put(state{penalties[height] = (reporter, -reporter_reward) :: (to, penalty) :: earlier})
165  
166    stateful entrypoint finalize_epoch(epoch_number : int, fork : bytes(), pc_hash : bytes(), producer : address, votes : list(vote)) =
         // Records the fork and votes for the current epoch. Callable only by the
         // current leader at the last height of the epoch, with epoch_number equal
         // to the current epoch and every vote passing validate_vote.
         // If a finalize record for this epoch already exists, only its fork and
         // votes are replaced; otherwise a new record is stored.
167      let ei = state.epochs[state.epoch]
168      require(Chain.block_height == ei.start + ei.length - 1, "Only in last block")
169      require(Call.caller == state.leader, "Must be called by the last leader of epoch")
170      require(epoch_number == state.epoch, "Not correct epoch")
171      require(List.all(validate_vote, votes), "Invalid vote")
172      let fresh = {epoch = epoch_number, fork = Some(fork), epoch_length = None, pc_hash = pc_hash, producer = producer, votes = Some(votes), length_votes = None}
173      let ei_final =
174        switch(state.finalize)
175          None => fresh
176          Some(prev) =>
177            if ( prev.epoch == epoch_number )
178              prev{ fork = Some(fork), votes = Some(votes) }
179            else
180              fresh
181      put(state{finalize = Some(ei_final)})
183  
184  
185    stateful entrypoint finalize_epoch_length(epoch_number : int, epoch_length : int, pc_hash : bytes(), producer : address, votes : list(epoch_length_vote)) =
         // Records a new epoch length and its votes for the current epoch. Callable
         // only by the current leader at the last height of the epoch, with
         // epoch_number equal to the current epoch and every vote passing
         // validate_length_vote.
         // If a finalize record for this epoch already exists, only its
         // epoch_length and length_votes are replaced; otherwise a new record is stored.
186      let ei = state.epochs[state.epoch]
187      require(Chain.block_height == ei.start + ei.length - 1, "Only in last block")
188      require(Call.caller == state.leader, "Must be called by the last leader of epoch")
189      require(epoch_number == state.epoch, "Not correct epoch")
190      require(List.all(validate_length_vote, votes), "Invalid vote")
191      let fresh = {epoch = epoch_number, fork = None, epoch_length = Some(epoch_length), pc_hash = pc_hash, producer = producer, votes = None, length_votes = Some(votes)}
192      let ei_final =
193        switch(state.finalize)
194          None => fresh
195          Some(prev) =>
196            if ( prev.epoch == epoch_number )
197              prev{ epoch_length = Some(epoch_length), length_votes = Some(votes) }
198            else
199              fresh
200      put(state{finalize = Some(ei_final)})
202  
203    entrypoint leader() =
         // The leader last stored by step, step_micro or step_eoe
         // (the contract's own address right after init).
204      state.leader
205  
206    entrypoint epoch() =
         // The current epoch number.
207      state.epoch
208  
209    entrypoint epoch_length() =
         // Length of the current epoch, read from its entry in state.epochs.
210      state.epochs[state.epoch].length
211  
212    entrypoint epoch_info() =
         // The current epoch number paired with that epoch's epoch_info record.
213      (state.epoch, state.epochs[state.epoch])
214  
215    entrypoint epoch_info_epoch(epoch : int) =
         // The epoch_info record of `epoch`. Only the previous epoch up to the
         // current epoch + 2 may be queried; anything else aborts.
216      require(epoch >= state.epoch - 1 && epoch =< state.epoch + 2, "Epoch not in scope")
217      state.epochs[epoch]
218  
219    entrypoint staking_contract() =
         // The staking contract that was passed to init.
220      state.main_staking_ct
221  
222    entrypoint validator_schedule(seed : bytes(), validators : list(address * int), length : int) =
         // Returns a list of `length` validator addresses derived from `seed`.
         // Each slot hashes the previous seed again, reduces the hash to a number
         // below the total stake and picks the validator whose cumulative stake
         // range contains that number (see pick_validator).
         //   seed       - starting entropy; hashed once before use
         //   validators - (address, stake) pairs
         //   length     - number of slots; 0 yields an empty list
223      let total_stake = List.foldl((+), 0, List.map(Pair.snd, validators))
         // Without this guard an empty or zero-stake validator list ends in a
         // `mod 0`, and a negative length never reaches the base case of the
         // recursion.
224      require(length == 0 || (length > 0 && total_stake > 0), "Schedule needs a non-negative length and a positive total stake")
         // One extra hash operation to convert from bytes() to bytes(32)/hash
225      validator_schedule_(Crypto.blake2b(seed), (s) => Bytes.to_int(s) mod total_stake, validators, length, [])
226  
227    entrypoint pin_info() =
         // The proof stored by pin, or None if none was stored since the last step_eoe.
228      state.pin
229  
230    entrypoint pin_reward_info() =
         // The current pin reward figures (base, current, carry_over).
231      state.pin_reward
232  
233    entrypoint finalize_info() =
         // The latest finalize record, or None if none has been written yet.
         // step_eoe does not reset it, so it may belong to an earlier epoch.
234      state.finalize
235  
236    function
         // Worker for validator_schedule: while slots remain, hash the seed, let
         // `rnd` turn the hash into a number and pick a validator with it. Picks
         // are collected in reverse and put in order when the count reaches 0.
237      validator_schedule_(prev_seed, rnd, validators, remaining, picked) =
238        if (remaining == 0) List.reverse(picked)
239        else
240          let next_seed = Crypto.blake2b(prev_seed)
241          validator_schedule_(next_seed, rnd, validators, remaining - 1, pick_validator(rnd(next_seed), validators) :: picked)
242  
243    function
         // Walks the (validator, stake) list and returns the first validator whose
         // stake exceeds what is left of n after subtracting the stakes before it.
         // There is no case for an empty list.
244      pick_validator(n, (validator, stake) :: rest) =
245        if (n < stake) validator else pick_validator(n - stake, rest)
246  
247    function assert_protocol_call() =
         // Aborts unless the caller is the account that created this contract.
248      require(Call.caller == Contract.creator, "Must be called by the protocol")
249  
250    function calc_carry_over(carry_over_flag : bool) =
         // Carry-over for the next epoch: the current base plus the current
         // carry-over when the flag is set, otherwise 0.
251      switch(carry_over_flag)
252        true  => state.pin_reward.base + state.pin_reward.carry_over
253        false => 0
256  
257    function calc_next_base_reward(next_base_pin_reward : int) =
         // Base pin reward for the next epoch: a negative argument keeps the
         // current base, any other value replaces it.
258      if ( next_base_pin_reward < 0 )
259        state.pin_reward.base
260      else
261        next_base_pin_reward
262  
263    stateful function pay_rewards(e : int) =
         // Sums the rewards recorded for every height of epoch `e` per address,
         // removes them from state and sends the totals, together with the summed
         // amount as call value, to the staking contract.
264      let sunset = state.epochs[e]
265      switch(pay_rewards_(sunset.start, sunset.length, {}, 0))
266        (per_address, total) => state.main_staking_ct.add_rewards(value = total, e, per_address)
267  
268    stateful function
         // Visits `remaining` consecutive heights starting at `height`. Rewards
         // found at a height are added to the per-address sums and the running
         // total, and the height is removed from state.rewards. Returns the sums
         // as a list together with the total.
269      pay_rewards_(_, 0, sums, total) = (Map.to_list(sums), total)
270      pay_rewards_(height, remaining, sums, total) =
271        switch(Map.lookup(height, state.rewards))
272          Some(entries) =>
273            put(state{rewards = Map.delete(height, state.rewards)})
274            let (sums1, total1) = List.foldl(pay_reward_, (sums, total), entries)
275            pay_rewards_(height + 1, remaining - 1, sums1, total1)
276          None => pay_rewards_(height + 1, remaining - 1, sums, total)
277  
278    function pay_reward_((acc, tot), (addr, amt)) =
         // Fold step: adds `amt` to the sum kept for `addr` (starting from 0) and
         // to the running total.
279      (acc{[addr] = acc[addr = 0] + amt}, tot + amt)
280  
281    stateful function collect_penalties(epoch : int) =
         // Sums the penalties recorded for every height of `epoch` per address,
         // removes them from state and passes them, largest first, to the staking
         // contract. step_eoe calls this before pay_rewards: penalties first to
         // fill pool, rewards second.
282      let sunset = state.epochs[epoch]
283      let largest_first = List.sort(compare_pens_, collect_penalties_(sunset.start, sunset.length, {}))
284      state.main_staking_ct.add_penalties(epoch, largest_first)
285  
286    stateful function
         // Visits `remaining` consecutive heights starting at `height`. Penalties
         // found at a height are added to the per-address sums and the height is
         // removed from state.penalties. Returns the sums as a list.
287      collect_penalties_(_, 0, sums) = Map.to_list(sums)
288      collect_penalties_(height, remaining, sums) =
289        switch(Map.lookup(height, state.penalties))
290          Some(entries) =>
291            put(state{penalties = Map.delete(height, state.penalties)})
292            let sums1 = List.foldl(collect_penalty_, sums, entries)
293            collect_penalties_(height + 1, remaining - 1, sums1)
294          None => collect_penalties_(height + 1, remaining - 1, sums)
295  
296    function collect_penalty_(acc, (addr, amt)) =
         // Fold step: adds `amt` to the sum kept for `addr`, starting from 0.
297      acc{[addr] = acc[addr = 0] + amt}
298  
299    // Ordering for List.sort that puts the entries with the largest amount first.
300    function compare_pens_((_, x), (_, y)) =
301      y < x
302  
303    function validate_vote(vote) =
         // A vote is valid when its signature over sign_data verifies against the
         // producer and sign_data contains the vote's hash as a byte sequence.
         // The containment check only runs when the signature is valid.
304      Crypto.verify_sig(vote.sign_data, vote.producer, vote.signature) && contains_bytes(Bytes.to_any_size(vote.hash), vote.sign_data)
308  
309    function validate_length_vote(vote) =
         // A length vote is valid when its signature over sign_data verifies
         // against the producer and sign_data contains the text
         // "epoch_length_delta" immediately followed by the vote's epoch_delta.
310      switch(Crypto.verify_sig(vote.sign_data, vote.producer, vote.signature))
311        false => false
312        true  =>
313          let needle = String.to_bytes(String.concats(["epoch_length_delta", Int.to_str(vote.epoch_delta)]))
314          contains_bytes(needle, vote.sign_data)
315  
316    function contains_bytes(needle : bytes(), haystack : bytes()) : bool =
         // True when `needle` occurs somewhere in `haystack` as a contiguous run
         // of bytes. The size of the needle is computed once and passed on.
317      contains_bytes_(Bytes.size(needle), needle, haystack)
318  
319    function contains_bytes_(len: int, needle : bytes(), haystack : bytes()) : bool =
         // Compares the first `len` bytes of the haystack with the needle; on a
         // mismatch drops one byte from the front and tries again. Gives up as
         // soon as the haystack is shorter than `len`.
320      if(Bytes.size(haystack) >= len)
321        let Some((prefix, _)) = Bytes.split_any(haystack, len)
322        if (prefix == needle) true
323        else
324          let Some((_, rest)) = Bytes.split_any(haystack, 1)
325          contains_bytes_(len, needle, rest)
326      else
327        false