/ coreblocks / func_blocks / fu / lsu / lsu_requester.py
lsu_requester.py
  1  from amaranth import *
  2  from amaranth_types import ModuleLike
  3  from transactron import Method, def_method, TModule
  4  from transactron.lib.simultaneous import condition
  5  from transactron.lib.logging import HardwareLogger
  6  from transactron.lib import BasicFifo
  7  
  8  from coreblocks.params import *
  9  from coreblocks.arch import Funct3, ExceptionCause
 10  from coreblocks.peripherals.bus_adapter import BusMasterInterface
 11  from coreblocks.interface.layouts import CommonLayoutFields, LSULayouts
 12  
 13  
class LSURequester(Elaboratable):
    """
    Bus request logic for the load/store unit. Its job is to interface
    between the LSU and the bus.

    A request is started with `issue` and its result is collected with
    `accept`. For every aligned request, `issue` stores the request arguments
    (addresses, `funct3`, store flag) in an internal FIFO; `accept` reads them
    back to decide which bus response to fetch and how to post-process it.
    Misaligned requests are never sent to the bus and are not put into the
    FIFO: `issue` reports the misalignment exception directly in its result.

    Attributes
    ----------
    issue : Method
        Issues a new request to the bus. Returns `exception` and `cause`
        fields; `exception` is set when the address is misaligned for the
        requested access width.
    accept : Method
        Retrieves a result from the bus. Returns `data`, `exception`, `cause`
        and `addr` (the virtual address of the request); `exception` is set
        when the bus response reports an error.
    """

    def __init__(self, gen_params: GenParams, bus: BusMasterInterface, depth: int = 4) -> None:
        """
        Parameters
        ----------
        gen_params : GenParams
            Parameters to be used during processor generation.
        bus : BusMasterInterface
            An instance of the bus master for interfacing with the data bus.
        depth : int
            Number of requests which can be sent to memory before it provides the first response.
            Describes the resiliency of `LSURequester` to memory latency in the case when memory
            is fully pipelined. Used as the depth of the internal request-arguments FIFO.
        """
        self.gen_params = gen_params
        self.bus = bus
        self.depth = depth

        lsu_layouts = gen_params.get(LSULayouts)

        self.issue = Method(i=lsu_layouts.issue, o=lsu_layouts.issue_out)
        self.accept = Method(o=lsu_layouts.accept)

        self.log = HardwareLogger("backend.lsu.requester")

    def prepare_bytes_mask(self, m: ModuleLike, funct3: Value, addr: Value) -> Signal:
        """
        Build the byte-select mask for a bus access.

        Parameters
        ----------
        m : ModuleLike
            Module to which the combinational assignments are added.
        funct3 : Value
            Access width selector, compared against `Funct3` values.
        addr : Value
            Address of the access; only bits 0-1 are inspected.

        Returns
        -------
        Signal
            Mask of width `xlen // bus granularity`. For `funct3` values not
            listed below no assignment is made here.
        """
        mask_len = self.gen_params.isa.xlen // self.bus.params.granularity
        mask = Signal(mask_len)
        with m.Switch(funct3):
            with m.Case(Funct3.B, Funct3.BU):
                # Byte access: a single mask bit at the position given by addr[0:2].
                m.d.av_comb += mask.eq(0x1 << addr[0:2])
            with m.Case(Funct3.H, Funct3.HU):
                # Halfword access: two adjacent mask bits, shifted by 0 or 2
                # depending on addr[1] (lower or upper half of the word).
                m.d.av_comb += mask.eq(0x3 << (addr[1] << 1))
            with m.Case(Funct3.W):
                # Word access: the four lowest mask bits.
                m.d.av_comb += mask.eq(0xF)
        return mask

    def postprocess_load_data(self, m: ModuleLike, funct3: Value, raw_data: Value, addr: Value) -> Signal:
        """
        Extract the loaded value from the data word returned by the bus.

        Parameters
        ----------
        m : ModuleLike
            Module to which the combinational assignments are added.
        funct3 : Value
            Access width/signedness selector, compared against `Funct3` values.
        raw_data : Value
            Data word as returned by the bus.
        addr : Value
            Address of the access; only bits 0-1 are inspected.

        Returns
        -------
        Signal
            Signal shaped like `raw_data`. For byte and halfword loads it holds
            the addressed byte/halfword moved to the lowest bits; for any other
            `funct3` it is `raw_data` unchanged.
        """
        data = Signal.like(raw_data)
        with m.Switch(funct3):
            with m.Case(Funct3.B, Funct3.BU):
                tmp = Signal(8)
                # Shift right by 8 * addr[0:2] bits to bring the addressed byte to bits 0-7.
                m.d.av_comb += tmp.eq((raw_data >> (addr[0:2] << 3)) & 0xFF)
                with m.If(funct3 == Funct3.B):
                    # Signed byte load: the byte is assigned as a signed value
                    # (sign-extended under Amaranth assignment rules).
                    m.d.av_comb += data.eq(tmp.as_signed())
                with m.Else():
                    # Unsigned byte load (BU): the byte is assigned as an unsigned value.
                    m.d.av_comb += data.eq(tmp)
            with m.Case(Funct3.H, Funct3.HU):
                tmp = Signal(16)
                # Shift right by 16 * addr[1] bits to bring the addressed halfword to bits 0-15.
                m.d.av_comb += tmp.eq((raw_data >> (addr[1] << 4)) & 0xFFFF)
                with m.If(funct3 == Funct3.H):
                    # Signed halfword load, handled like the signed byte case.
                    m.d.av_comb += data.eq(tmp.as_signed())
                with m.Else():
                    # Unsigned halfword load (HU).
                    m.d.av_comb += data.eq(tmp)
            with m.Default():
                # Any other width: the bus word is passed through unchanged.
                m.d.av_comb += data.eq(raw_data)
        return data

    def prepare_data_to_save(self, m: ModuleLike, funct3: Value, raw_data: Value, addr: Value) -> Signal:
        """
        Place the value to be stored at the position selected by the address.

        Parameters
        ----------
        m : ModuleLike
            Module to which the combinational assignments are added.
        funct3 : Value
            Access width selector, compared against `Funct3` values.
        raw_data : Value
            Value to be stored; the low 8 or 16 bits are used for byte and
            halfword stores.
        addr : Value
            Address of the access; only bits 0-1 are inspected.

        Returns
        -------
        Signal
            Signal shaped like `raw_data`, holding the data word to be sent to the bus.
        """
        data = Signal.like(raw_data)
        with m.Switch(funct3):
            with m.Case(Funct3.B):
                # Move the low byte left by 8 * addr[0:2] bits.
                m.d.av_comb += data.eq(raw_data[0:8] << (addr[0:2] << 3))
            with m.Case(Funct3.H):
                # Move the low halfword left by 16 * addr[1] bits.
                m.d.av_comb += data.eq(raw_data[0:16] << (addr[1] << 4))
            with m.Default():
                # Any other width: the value is passed through unchanged.
                m.d.av_comb += data.eq(raw_data)
        return data

    def check_align(self, m: TModule, funct3: Value, addr: Value) -> Signal:
        """
        Check whether the address is aligned for the requested access width.

        Parameters
        ----------
        m : TModule
            Module to which the combinational assignments are added.
        funct3 : Value
            Access width selector, compared against `Funct3` values.
        addr : Value
            Address of the access; only bits 0-1 are inspected.

        Returns
        -------
        Signal
            One-bit signal, set when the address is aligned.
        """
        aligned = Signal()
        with m.Switch(funct3):
            with m.Case(Funct3.W):
                # Word access: both low address bits must be zero.
                m.d.av_comb += aligned.eq(addr[0:2] == 0)
            with m.Case(Funct3.H, Funct3.HU):
                # Halfword access: the lowest address bit must be zero.
                m.d.av_comb += aligned.eq(addr[0] == 0)
            with m.Default():
                # Every other funct3 value (including byte accesses) is treated as aligned.
                m.d.av_comb += aligned.eq(1)
        return aligned

    def elaborate(self, platform):
        # NOTE(review): `av_comb` and `top_comb` used throughout this class are
        # Transactron-specific combinational domains; their exact gating rules
        # are defined by Transactron, not visible in this file.
        m = TModule()

        layouts = self.gen_params.get(CommonLayoutFields)
        # FIFO carrying the arguments of requests already sent to the bus from
        # `issue` to `accept`. Its depth limits the number of requests in flight.
        m.submodules.args_fifo = args_fifo = BasicFifo(
            [
                layouts.vaddr,
                layouts.paddr,
                ("funct3", Funct3),
                ("store", 1),
            ],
            self.depth,
        )

        @def_method(m, self.issue)
        def _(paddr: Value, vaddr: Value, data: Value, funct3: Value, store: Value):
            exception = Signal()
            cause = Signal(ExceptionCause)

            # Alignment, byte mask and store data are all derived from the physical address.
            aligned = self.check_align(m, funct3, paddr)
            bytes_mask = self.prepare_bytes_mask(m, funct3, paddr)
            bus_data = self.prepare_data_to_save(m, funct3, data, paddr)

            self.log.debug(
                m,
                1,
                "issue addr=0x{:08x} data=0x{:08x} funct3={} store={} aligned={}",
                vaddr,
                data,
                funct3,
                store,
                aligned,
            )

            # Only aligned requests reach the bus. The two low address bits are
            # dropped (`paddr >> 2`); the byte position within the word is
            # conveyed by `sel` instead.
            # NOTE(review): `nonblocking=True` presumably allows the method to
            # run when neither branch is taken (the misaligned case) — confirm
            # against the Transactron `condition` documentation.
            with condition(m, nonblocking=True) as branch:
                with branch(aligned & store):
                    self.bus.request_write(m, addr=paddr >> 2, data=bus_data, sel=bytes_mask)
                with branch(aligned & ~store):
                    self.bus.request_read(m, addr=paddr >> 2, sel=bytes_mask)

            with m.If(aligned):
                # Remember the request arguments so `accept` can match the response.
                args_fifo.write(m, paddr=paddr, vaddr=vaddr, funct3=funct3, store=store)
            with m.Else():
                # Misaligned request: nothing was sent to the bus, report the
                # exception immediately in the result of `issue`.
                m.d.av_comb += exception.eq(1)
                m.d.av_comb += cause.eq(
                    Mux(store, ExceptionCause.STORE_ADDRESS_MISALIGNED, ExceptionCause.LOAD_ADDRESS_MISALIGNED)
                )

            return {"exception": exception, "cause": cause}

        @def_method(m, self.accept)
        def _():
            data = Signal(self.gen_params.isa.xlen)
            exception = Signal()
            cause = Signal(ExceptionCause)
            err = Signal()

            # Arguments of the oldest request still waiting for its response.
            request_args = args_fifo.read(m)
            self.log.debug(m, 1, "accept data=0x{:08x} exception={} cause={}", data, exception, cause)

            # Fetch the response matching the kind of the request: a write
            # response for stores, a read response otherwise.
            # NOTE(review): `branch()` without a condition is presumably the
            # default branch, taken when `request_args.store` is not set —
            # confirm against the Transactron `condition` documentation.
            with condition(m) as branch:
                with branch(request_args.store):
                    fetched = self.bus.get_write_response(m)
                    m.d.comb += err.eq(fetched.err)
                with branch():
                    fetched = self.bus.get_read_response(m)
                    m.d.comb += err.eq(fetched.err)
                    # Extract the loaded byte/halfword/word from the bus data
                    # using the width and address saved at issue time.
                    m.d.top_comb += data.eq(
                        self.postprocess_load_data(m, request_args.funct3, fetched.data, request_args.paddr)
                    )

            with m.If(err):
                # A bus error is reported as an access fault of the matching kind.
                m.d.av_comb += exception.eq(1)
                m.d.av_comb += cause.eq(
                    Mux(request_args.store, ExceptionCause.STORE_ACCESS_FAULT, ExceptionCause.LOAD_ACCESS_FAULT)
                )

            # The virtual address of the request is returned as `addr`.
            return {"data": data, "exception": exception, "cause": cause, "addr": request_args.vaddr}

        return m