/ Misc / src / sortedarray.jl
sortedarray.jl
  1  import Serialization: serialize, deserialize, AbstractSerializer, serialize_type
  2  
  3  _dosort!(arr::AbstractVector, args...; dims=1, kwargs...) = sort!(arr, args...; kwargs...)
  4  _dosort!(arr, args...; kwargs...) = sort!(arr, args...; kwargs...)
  5  
  6  struct SortedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
  7      arr::A
  8      opts::NamedTuple{(:dims, :rev, :by),Tuple{Int,Bool,Function}}
  9      function SortedArray(
 10          arr::A=Vector[]; dims=1, rev=false, by=identity
 11      ) where {A<:AbstractArray}
 12          new{eltype(A),ndims(A),A}(_dosort!(arr; dims, rev, by), (; dims, rev, by))
 13      end
 14  end
 15  
 16  function Base.setindex!(::SortedArray, args...; kwargs...)
 17      error("setindex! not allowed for a SortedArray")
 18  end
 19  
 20  function Base.pushfirst!(::SortedArray, args...; kwargs...)
 21      error("pushfirst! not allowed for a SortedArray")
 22  end
 23  
 24  function Base.insert!(::SortedArray, args...; kwargs...)
 25      error("insert! not allowed for a SortedArray")
 26  end
 27  
 28  function Base.permute!(::SortedArray)
 29      error("permute! not allowed for a SortedArray")
 30  end
 31  
 32  function Base.invpermute!(::SortedArray)
 33      error("invpermute! not allowed for a SortedArray")
 34  end
 35  
 36  function Base.getindex(sa::SortedArray, idx)
 37      sa.arr[idx]
 38  end
 39  
 40  function Base.get(sa::SortedArray, i::Integer, default)
 41      get(sa.arr, i, default)
 42  end
 43  
 44  function Base.popfirst!(sa::SortedArray, args...)
 45      popfirst!(sa.arr, args...)
 46  end
 47  
 48  function Base.pop!(sa::SortedArray, args...)
 49      pop!(sa.arr, args...)
 50  end
 51  
 52  function Base.popat!(sa::SortedArray, args...)
 53      popat!(sa.arr, args...)
 54  end
 55  
 56  function Base.splice!(sa::SortedArray, args...)
 57      splice!(sa.arr, args...)
 58  end
 59  
 60  function Base.deleteat!(sa::SortedArray, args...)
 61      deleteat!(sa.arr, args...)
 62  end
 63  
 64  Base.sort!(sa::SortedArray, args...; kwargs...) = sa
 65  
 66  function Base.searchsorted(sa::SortedArray, args...; kwargs...)
 67      searchsorted(sa.arr, args...; kwargs...)
 68  end
 69  function Base.searchsortedfirst(sa::SortedArray, args...; kwargs...)
 70      searchsortedfirst(sa.arr, args...; kwargs...)
 71  end
 72  function Base.searchsortedlast(sa::SortedArray, args...; kwargs...)
 73      searchsortedlast(sa.arr, args...; kwargs...)
 74  end
 75  
 76  Base.length(sa::SortedArray) = length(sa.arr)
 77  Base.iterate(sa::SortedArray, args...; kwargs...) = iterate(sa.arr, args...; kwargs...)
 78  function Base.copy(sa::SortedArray)
 79      SortedArray(copy(sa.arr); sa.opts...)
 80  end
 81  
 82  function Base.reverse(sa::SortedArray)
 83      SortedArray(reverse(sa.arr); sa.opts..., rev=!sa.opts.rev)
 84  end
 85  
 86  function Base.permutedims(sa::SortedArray)
 87      SortedArray(
 88          permutedims(sa.arr); dims=ifelse(sa.opts.dims == 1, 2, 1), sa.opts.rev, sa.opts.by
 89      )
 90  end
 91  
 92  function Base.push!(sa::SortedArray{T,1,A}, value) where {T,A<:AbstractVector{T}}
 93      index = searchsortedfirst(sa.arr, value; rev=sa.opts.rev, by=sa.opts.by)
 94      insert!(sa.arr, index, value)
 95      return sa
 96  end
 97  
 98  function Base.push!(sa::SortedArray{T,N,A}, value) where {T,N,A<:AbstractArray{T,N}}
 99      error("push! is only supported for 1-dimensional SortedArray")
100  end
101  
102  function Base.append!(sa::SortedArray{T,1,A}, values) where {T,A<:AbstractVector{T}}
103      append!(sa.arr, values)
104      _dosort!(sa.arr; sa.opts...)
105      return sa
106  end
107  
108  function Base.append!(sa::SortedArray{T,N,A}, values) where {T,N,A<:AbstractArray{T,N}}
109      error("append! is only supported for 1-dimensional SortedArray")
110  end
111  
112  function Base.vcat(sas::SortedArray{T,1,A}...) where {T,A<:AbstractVector{T}}
113      isempty(sas) && return SortedArray{T,1,A}()
114      new_arr = vcat([sa.arr for sa in sas]...)
115      opts = sas[1].opts
116      _dosort!(new_arr; opts.dims, opts.rev, opts.by)
117      return SortedArray(new_arr; opts...)
118  end
119  
120  function Base.hcat(sas::SortedArray{T,N,A}...) where {T,N,A<:AbstractArray{T,N}}
121      if any(sa -> sa.opts.dims != 1, sas)
122          error("hcat is only supported for SortedArray with sorting dimension 1")
123      end
124      new_arr = hcat([sa.arr for sa in sas]...)
125      opts = sas[1].opts
126      _dosort!(new_arr; dims=1, opts.rev, opts.by)
127      return SortedArray(new_arr; dims=1, opts.rev, opts.by)
128  end
129  
130  function Base.cat(
131      sas::SortedArray{T,N,A}...; dims::Integer
132  ) where {T,N,A<:AbstractArray{T,N}}
133      if any(sa -> sa.opts.dims != sas[1].opts.dims, sas)
134          error("All SortedArray instances must have the same sorting dimension")
135      end
136      new_arr = cat([sa.arr for sa in sas]...; dims=dims)
137      opts = sas[1].opts
138      new_dims = opts.dims + (dims <= opts.dims)
139      _dosort!(new_arr; dims=new_dims, opts.rev, opts.by)
140      return SortedArray(new_arr; dims=new_dims, opts.rev, opts.by)
141  end
142  
143  Base.size(sa::SortedArray) = size(sa.arr)
144  function Base.similar(sa::SortedArray, ::Type{S}, dims::Dims) where {S}
145      SortedArray(similar(sa.arr, S, dims); sa.opts...)
146  end
147  Base.empty!(sa::SortedArray) = (empty!(sa.arr); sa)
148  function Base.empty(sa::SortedArray)
149      SortedArray(empty(sa.arr); sa.opts...)
150  end
151  Base.isempty(sa::SortedArray) = isempty(sa.arr)
152  Base.firstindex(sa::SortedArray) = firstindex(sa.arr)
153  Base.lastindex(sa::SortedArray) = lastindex(sa.arr)
154  Base.axes(sa::SortedArray) = axes(sa.arr)
155  Base.eltype(::Type{SortedArray{A}}) where {A<:AbstractArray} = eltype(A)
156  Base.collect(sa::SortedArray) = collect(sa.arr)
157  
158  function Base.map(f, sa::SortedArray)
159      new_arr = map(f, sa.arr)
160      SortedArray(new_arr; sa.opts...)
161  end
162  
163  function Base.filter(f, sa::SortedArray)
164      new_arr = filter(f, sa.arr)
165      SortedArray(new_arr; sa.opts...)
166  end
167  
168  Base.reduce(f, sa::SortedArray; kwargs...) = reduce(f, sa.arr; kwargs...)
169  Base.foldl(f, sa::SortedArray; kwargs...) = foldl(f, sa.arr; kwargs...)
170  Base.any(f::Function, sa::SortedArray) = any(f, sa.arr)
171  Base.all(f::Function, sa::SortedArray) = Base.all(f, sa.arr)
172  Base.in(x, sa::SortedArray) = in(x, sa.arr)
173  
174  function Base.show(
175      io::IO, ::MIME"text/plain", sa::SortedArray{T,N,A}
176  ) where {T,N,A<:AbstractArray{T,N}}
177      print(io, "SortedArray{$T,$N}(")
178      print(io, sa.arr)
179      print(io, "; dims=$(sa.opts.dims), rev=$(sa.opts.rev), by=$(sa.opts.by))")
180  end
181  
182  function Base.map!(f, sa::SortedArray, src::AbstractArray...)
183      map!(f, sa.arr, src...)
184      _dosort!(sa.arr; sa.opts...)
185      return sa
186  end
187  
188  function Base.filter!(f, sa::SortedArray)
189      filter!(f, sa.arr)
190      _dosort!(sa.arr; sa.opts...)
191      return sa
192  end
193  
194  Base.view(sa::SortedArray, args...) = view(sa.arr, args...)
195  
196  function Base.resize!(sa::SortedArray, n)
197      resize!(sa.arr, n)
198      _dosort!(sa.arr; sa.opts...)
199      return sa
200  end
201  
202  function Base.unique!(sa::SortedArray)
203      unique!(sa.arr)
204      return sa
205  end
206  
207  function Base.unique(sa::SortedArray)
208      SortedArray(unique(sa.arr); sa.opts...)
209  end
210  
211  function Base.intersect(sa1::SortedArray, sa2::SortedArray)
212      SortedArray(intersect(sa1.arr, sa2.arr); sa1.opts...)
213  end
214  
215  function Base.union(sa1::SortedArray, sa2::SortedArray)
216      SortedArray(union(sa1.arr, sa2.arr); sa1.opts...)
217  end
218  
219  function Base.setdiff(sa1::SortedArray, sa2::SortedArray)
220      SortedArray(setdiff(sa1.arr, sa2.arr); sa1.opts...)
221  end
222  
223  function Base.findall(f::Function, sa::SortedArray)
224      SortedArray(findall(f, sa.arr); sa.opts...)
225  end
226  
227  function Base.findfirst(f::Function, sa::SortedArray)
228      findfirst(f, sa.arr)
229  end
230  
231  function Base.findlast(f::Function, sa::SortedArray)
232      findlast(f, sa.arr)
233  end
234  
235  Base.sum(sa::SortedArray) = sum(sa.arr)
236  Base.prod(sa::SortedArray) = prod(sa.arr)
237  Base.maximum(sa::SortedArray) = maximum(sa.arr)
238  Base.minimum(sa::SortedArray) = minimum(sa.arr)
239  
240  function Base.broadcast(f, sa::SortedArray, args...)
241      result = broadcast(f, sa.arr, args...)
242      SortedArray(_dosort!(result; sa.opts...); sa.opts...)
243  end
244  
245  Base.sort(sa::SortedArray) = copy(sa)
246  
247  Base.issorted(::SortedArray) = true
248  
249  function Base.merge(sa1::SortedArray, sa2::SortedArray)
250      SortedArray(merge(sa1.arr, sa2.arr); sa1.opts...)
251  end
252  
253  function Base.merge!(sa1::SortedArray, sa2::SortedArray)
254      merge!(sa1.arr, sa2.arr)
255      _dosort!(sa1.arr; sa1.opts...)
256      return sa1
257  end
258  
259  function Base.partialsort(sa::SortedArray, k; kwargs...)
260      SortedArray(partialsort(sa.arr, k; kwargs...); sa.opts...)
261  end
262  
263  function Base.partialsortperm(sa::SortedArray, k; kwargs...)
264      SortedArray(partialsortperm(sa.arr, k; kwargs...); sa.opts...)
265  end
266  
267  function Base.accumulate(f, sa::SortedArray; dims=sa.opts.dims)
268      result = accumulate(f, sa.arr; dims=dims)
269      SortedArray(result; sa.opts...)
270  end
271  
272  Base.cumsum(sa::SortedArray; dims=1) = accumulate(+, sa; dims=dims)
273  
274  function serialize(s::AbstractSerializer, sa::A) where {A<:SortedArray}
275      serialize_type(s, A, false)
276      serialize(s, (sa.arr, sa.opts))
277  end
278  
279  function deserialize(buf::AbstractSerializer, ::Type{<:SortedArray})
280      arr, opts = deserialize(buf)
281      SortedArray(arr; opts...)
282  end
283  
284  function Base.sizehint!(sa::SortedArray, v)
285      sizehint!(sa.arr, v)
286  end
287  
288  export SortedArray