Here’s another attempt with value types:
# Tuple-splicing helpers.
#
# `replace_tuple(old_index, replace_index, should_replace)` walks `old_index`
# from left to right. `should_replace` holds one `Val{true}()`/`Val{false}()`
# flag per position: a `Val{true}` position is swapped for the next unused
# element of `replace_index`, a `Val{false}` position keeps its old entry.
# Once `replace_index` is used up, the untouched remainder of `old_index` is
# returned unchanged.

# Flag is `Val{true}`: emit the next replacement value in place of the old
# entry, then continue with the remaining replacement values.
@inline function dispatch_on_value(::Val{true}, first_old_index, tail_old_index,
                                   replace_index, should_replace)
    substitute = replace_index[1]
    remainder = replace_tuple(tail_old_index, Base.tail(replace_index), should_replace)
    return (substitute, remainder...)
end

# Flag is `Val{false}`: keep the old entry; no replacement value is consumed.
@inline function dispatch_on_value(::Val{false}, first_old_index, tail_old_index,
                                   replace_index, should_replace)
    remainder = replace_tuple(tail_old_index, replace_index, should_replace)
    return (first_old_index, remainder...)
end

# Base case: nothing left to substitute, so the rest of `old_index` passes through.
@inline replace_tuple(old_index, ::Tuple{}, should_replace) = old_index

# General case: peel the leading flag and leading old entry off their tuples and
# let dispatch on the flag's `Val` type decide what goes in this position.
@inline function replace_tuple(old_index, replace_index, should_replace)
    flag = should_replace[1]
    head = old_index[1]
    rest = Base.tail(old_index)
    later_flags = Base.tail(should_replace)
    return dispatch_on_value(flag, head, rest, replace_index, later_flags)
end
# Fill `result` by applying `f` to one slice of `input` per entry of `indexes`.
#
# For each `index`, `replace_tuple` merges `index.I` into `input_index` and
# `result_index` (guided by `dimension_is_indexed`) to obtain the indices of
# the slice to read from `input` and of the region of `result` to write.
# `input_slice` is a buffer that is overwritten in place when `safe_for_reuse`
# is true. Returns `result`.
@inline function inner_mapslices!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    if safe_for_reuse
        # When f returns an array, storing it into `result` copies the elements,
        # so the same `input_slice` buffer can be refilled on every iteration.
        for position in indexes
            source = replace_tuple(input_index, position.I, dimension_is_indexed)
            Base._unsafe_getindex!(input_slice, input, source...)
            output = f(input_slice)
            result[replace_tuple(result_index, position.I, dimension_is_indexed)...] = output
        end
        return result
    end
    # Reuse cannot be guaranteed safe (#18524): index out a fresh slice each time.
    for position in indexes
        fresh_slice = input[replace_tuple(input_index, position.I, dimension_is_indexed)...]
        output = f(fresh_slice)
        result[replace_tuple(result_index, position.I, dimension_is_indexed)...] = output
    end
    return result
end
# Type-inference diagnostic for the inner kernel.
# NOTE(review): `result`, `dimension_is_indexed`, `indexes`, `input_index`,
# `result_index`, `safe_for_reuse` and `input_slice` are not bound at top level
# anywhere in the visible source (they are locals of `my_mapslices`), so this
# presumably was run interactively with those values in scope - confirm.
# Vectors are splatted into tuples with `(x...,)`; the bare `(x...)` spelling is
# rejected by Julia 1.0 and later.
@code_warntype inner_mapslices!(f, input, result,
    (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
    safe_for_reuse, input_slice
)
# Non-inlined entry point for `inner_mapslices!`: forwards all nine arguments
# unchanged and returns whatever `inner_mapslices!` returns.
# NOTE(review): presumably serves as a function barrier so the loop is compiled
# for the concrete tuple types built by the caller - confirm intent.
@noinline function inner_mapslices_noinline!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    return inner_mapslices!(f, input, result,
        dimension_is_indexed, indexes, input_index, result_index,
        safe_for_reuse, input_slice
    )
end
# Apply `f` to every slice of `input` spanning `sliced_dimensions` and collect
# the outputs in a newly allocated array.
#
# Arguments:
# - `f`: function applied to each slice.
# - `input`: the array to slice.
# - `sliced_dimensions`: dimensions covered by each slice. When empty, the
#   function returns `map(f, input)`.
#
# Returns an array created with `similar` from `f`'s first output. Along the
# sliced dimensions it has the axes of that first output (padded with length-1
# axes when the output has fewer dimensions than were sliced); the remaining
# dimensions keep the axes of `input`.
#
# The first slice is evaluated here to size the result; the remaining slices are
# processed by `inner_mapslices_noinline!`, which receives the index templates
# and per-dimension `Val` flags as tuples.
#
# NOTE(review): `indices` and `CartesianRange` are pre-1.0 Base names (`axes`
# and `CartesianIndices` from Julia 0.7 on); they are kept as-is here.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f, input)
    end
    input_axes = [indices(input)...]
    rank = ndims(input)
    # Index template for the first slice: the first position along every
    # dimension, with the sliced dimensions widened to their full extent.
    input_index = Any[first(index) for index in indices(input)]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end
    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse =
        isa(input_slice, StridedArray) &&
        (isa(first_output, Number) ||
         (isa(first_output, AbstractArray) && eltype(first_output) <: Number))
    # determine result size and allocate
    result_axes = copy(input_axes)
    # TODO: maybe support removing dimensions
    # Scalars and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    number_of_trivial_output_axes =
        max(0, length(sliced_dimensions) - ndims(first_output))
    # NOTE(review): `result_axes` is copied from `indices(input)`, so the `Int`
    # branch looks reachable only if `indices` yields plain integers - confirm.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] =
            [size(first_output)...,
             ntuple(dimension -> 1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] =
            [indices(first_output)...,
             ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))
    # Index template for the region of `result` that receives one slice's output.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end
    result[result_index...] = first_output
    # Iterate over the non-sliced dimensions only;
    # skip the first element, we already handled it
    iterated_dimensions = setdiff([1:rank;], sliced_dimensions)
    indexes = Iterators.drop(CartesianRange(tuple(input_axes[iterated_dimensions]...)), 1)
    # One flag per dimension: `Val{true}()` when the dimension is iterated over
    # (its index gets replaced per slice), `Val{false}()` when it is sliced.
    dimension_is_indexed = map(1:rank) do dimension
        if dimension in sliced_dimensions
            Val{false}()
        else
            Val{true}()
        end
    end
    # Vectors are converted to tuples with `(x...,)`; the bare `(x...)` form is
    # rejected by Julia 1.0 and later.
    return inner_mapslices_noinline!(f, input, result,
        (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
        safe_for_reuse, input_slice
    )
end
# Benchmark driver: compare Base `mapslices`, `my_mapslices` and `sum` on the
# same 5x5 random matrix, slicing along dimension 1.
f = sum
input = rand(5, 5)
sliced_dimensions = [1]
using BenchmarkTools
# Sanity check before timing: both implementations must agree.
Test.@test mapslices(f, input, sliced_dimensions) == my_mapslices(f, input, sliced_dimensions)
# `f`, `input` and `sliced_dimensions` are non-constant globals; interpolating
# them with `$` keeps global-variable lookup and dynamic dispatch out of the
# measured time.
result1 = @benchmark mapslices($f, $input, $sliced_dimensions)
result2 = @benchmark my_mapslices($f, $input, $sliced_dimensions)
result3 = @benchmark sum($input, $sliced_dimensions)