New method skipSorted(n) in StreamEx  #61

@amaembo

Requested here by @AgnetaWalterscheidt:

I would like to implement a new method in StreamEx (or AbstractStreamEx):

public StreamEx<E> skipSorted(int n, Comparator<? super E> comparator);

The method skips the first n elements of the stream as if the stream were sorted with the comparator, but without actually sorting the whole stream.
The method is similar to MoreCollectors.greatest(Comparator<? super T>, int) but

  • it uses the quickselect algorithm to do the partial sort, and
  • it returns a stream instead of a Collector, which makes sense because skipping n elements of an infinite stream still leaves a stream with an infinite number of elements.
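As a point of reference, the intended result can be stated with plain JDK streams: sort everything, then skip n. The sketch below only pins down the semantics (the class and method names are made up for illustration). It sorts the whole input, so it cannot handle infinite streams, which is exactly what the proposal avoids. Also note that the proposed algorithm delivers each remaining element as soon as it is known not to be among the n smallest, so unlike this reference its output is generally not in sorted order; only the set of remaining elements is the same.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SkipSortedReference {
    // Naive reference semantics for the proposed skipSorted (hypothetical name):
    // fully sort the stream, then drop the n smallest elements.
    // Works only for finite streams, because sorted() must see every element first.
    public static <E> Stream<E> skipSortedNaive(Stream<E> stream, int n, Comparator<? super E> comparator) {
        return stream.sorted(comparator).skip(n);
    }

    public static void main(String[] args) {
        List<Integer> rest = skipSortedNaive(Stream.of(5, 1, 4, 2, 3), 2, Comparator.<Integer>naturalOrder())
                .collect(Collectors.toList());
        System.out.println(rest); // prints [3, 4, 5]
    }
}
```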

For symmetry I would like to implement a second method:

public StreamEx<E> limitSorted(int n, Comparator<? super E> comparator);

which is the counterpart of the first method: it limits the stream to the first n elements (after "sorting" it as in the first method). This method obviously never returns a stream with an infinite number of elements (the result has at most n elements), but I think it would confuse users of your library to have to look for it somewhere else.
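For limitSorted the memory bound follows from the result size: only n elements ever need to be kept. A minimal sketch, assuming an Iterator-based helper in the same spirit as the skipSorted() below (class and method names are hypothetical, not StreamEx API), keeps a max-heap of the n smallest elements seen so far and sorts it at the end. A heap is used instead of quickselect purely to keep the sketch short.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class LimitSortedSketch {
    // Hypothetical helper (not StreamEx API): returns the n smallest elements according
    // to the comparator, in sorted order, while keeping at most n elements in memory.
    public static <E> List<E> limitSorted(Iterator<E> it, int n, Comparator<? super E> comparator) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative but was: " + n);
        }
        if (n == 0) {
            return new ArrayList<>();
        }
        // Max-heap (reversed comparator): the head is the largest of the kept elements.
        PriorityQueue<E> heap = new PriorityQueue<E>(n, (a, b) -> comparator.compare(b, a));
        while (it.hasNext()) {
            E e = it.next();
            if (heap.size() < n) {
                heap.add(e);
            } else if (comparator.compare(e, heap.peek()) < 0) {
                heap.poll(); // e displaces the current largest kept element
                heap.add(e);
            }
        }
        List<E> result = new ArrayList<>(heap);
        result.sort(comparator);
        return result;
    }
}
```

Returning a List rather than a lazy stream reflects that the result is always finite; it costs O(n) memory and O(log n) heap work per input element.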

I have already implemented a method skipSorted() that takes and returns an Iterator. You can find the code below, including a small test method (currently called from a main method).
I have also implemented two other versions of skipSorted() (not included, but I can send you the code if you are interested): one that uses a PriorityQueue for sorting and one that uses a TreeSet. On average, the quickselect implementation is about three times faster than either of them.
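The PriorityQueue variant is not included in the issue, so the following is an independent sketch of how such a version could look, not the author's code (the class name is hypothetical). It keeps a max-heap of the n smallest elements seen so far: each incoming element is either delivered at once (if it is not smaller than the heap's maximum) or displaces that maximum, which is then delivered instead. Because every element is handled as it arrives, it also works on an infinite source.

```java
import java.util.Comparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;

public class HeapSkipSorted {
    // Independent sketch of a PriorityQueue-based skipSorted (hypothetical class name).
    // The heap holds the n smallest elements seen so far; everything else is delivered lazily.
    public static <E> Iterator<E> skipSorted(Iterator<E> source, int n, Comparator<? super E> comparator) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative but was: " + n);
        }
        // Max-heap via the reversed comparator; PriorityQueue needs a capacity >= 1.
        PriorityQueue<E> heap = new PriorityQueue<E>(Math.max(1, n), (a, b) -> comparator.compare(b, a));
        return new Iterator<E>() {
            private E pending;          // next element to deliver, valid if hasPending
            private boolean hasPending;

            @Override
            public boolean hasNext() {
                while (!hasPending && source.hasNext()) {
                    E e = source.next();
                    if (heap.size() < n) {
                        heap.add(e);                 // might still be among the n smallest
                    } else if (n == 0 || comparator.compare(e, heap.peek()) >= 0) {
                        pending = e;                 // cannot be among the n smallest
                        hasPending = true;
                    } else {
                        pending = heap.poll();       // old maximum drops out of the n smallest
                        heap.add(e);
                        hasPending = true;
                    }
                }
                return hasPending;
            }

            @Override
            public E next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                E e = pending;
                pending = null;
                hasPending = false;
                return e;
            }
        };
    }
}
```

Here every element costs O(log n) heap work, whereas the quickselect version amortizes one expected-linear partition over each batch of n delivered elements, which plausibly explains the speedup reported above.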

The StreamEx method could be implemented as a small wrapper around these methods.
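One way to wrap an Iterator-based implementation back into a stream is sketched below using only JDK classes (the helper name wrap is hypothetical). Since skipSorted() as written below already reads the first n elements when it is called, the sketch defers the call through a Spliterator supplier, so nothing is consumed until a terminal operation runs; it also forwards close() to the source stream.

```java
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.UnaryOperator;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class IteratorStreamWrapper {
    // Hypothetical helper: applies an Iterator-to-Iterator transformation (for example an
    // Iterator-based skipSorted) to a stream and returns a new, still lazy stream.
    public static <E> Stream<E> wrap(Stream<E> source, UnaryOperator<Iterator<E>> transform) {
        // The supplier is invoked only when a terminal operation starts, so neither
        // source.iterator() nor the transformation runs before then.
        // ORDERED: the encounter order is whatever order the transformed iterator produces.
        return StreamSupport.<E>stream(
                        () -> Spliterators.spliteratorUnknownSize(transform.apply(source.iterator()), Spliterator.ORDERED),
                        Spliterator.ORDERED, false)
                .onClose(source::close); // closing the result also closes the source
    }
}
```

In StreamEx itself the result would presumably be wrapped into a StreamEx rather than returned as a plain Stream; that part is left out of the sketch.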

Please tell me if you would like me to implement this!

Cheers
Agneta

Code:

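import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Random;
import java.util.TreeMap;
import java.util.stream.StreamSupport;

/**
 * Implementation details: a buffer of size 2*n is filled. Then this buffer is partially sorted with the quickselect algorithm
 * (see below). Then the elements from the right half of the buffer are returned. When all these elements have been returned,
 * the buffer is filled again.
 * Additionally, while the buffer is being filled, elements are returned immediately if they are greater than or equal to the
 * nth smallest element seen so far.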
 * 
 * @param iterator        the iterator that returns the elements of which n elements are to be skipped
 * @param comparator      the comparator used for sorting
 * @param n               the number of elements to skip
 * @return                an iterator that will return the rest of the elements
 */
public static <E> Iterator<E> skipSorted(Iterator<E> iterator, Comparator<? super E> comparator, int n) {

    Iterator<E> result = precheck(iterator, comparator, n);

    if (result == null) {

        @SuppressWarnings("unchecked") // unchecked but safe: the array never escapes and only ever holds E instances
        E[] buffer = (E[]) new Object[n * 2];

        // nthSmallest is the maximum of the elements buffered so far; once n elements have been seen it is the nth smallest
        // element seen so far, so any later element that is greater than or equal to it can be delivered immediately
        // (it cannot be among the n smallest)
        E nthSmallest = iterator.next();

        buffer[0] = nthSmallest; 
        int bufferIndex = 1;

        while (bufferIndex < n && iterator.hasNext()) {
            E e = iterator.next();
            buffer[bufferIndex++] = e;
            if (comparator.compare(e, nthSmallest) > 0) {
                nthSmallest = e;
            }
        }

        if (!iterator.hasNext()) { // at most n elements: all of them are skipped
            result = iterator; // already exhausted, so it returns nothing
        } else {
            result = new SkipSortedIterator<>(iterator, n, buffer, comparator, nthSmallest);
        }
    }   

    return result;
}
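The buffering strategy from the Javadoc above can be illustrated without the iterator bookkeeping. The following eager sketch (hypothetical class name, collects into a List) follows the same steps: buffer up to 2n elements, deliver anything that is not smaller than the current nth smallest at once, and when the buffer is full keep the n smallest and deliver the rest. To keep it short, a full List.sort stands in for quickselect, it only starts delivering early after the first full buffer (the original already does so after the first n elements), and it assumes non-null elements because null marks the not-yet-known threshold.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

public class BufferedSkipSketch {
    // Eager illustration of the 2n-buffer strategy (hypothetical class name).
    // List.sort stands in for quickselect; elements are assumed to be non-null.
    public static <E> List<E> skipSorted(Iterator<E> it, int n, Comparator<? super E> cmp) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative but was: " + n);
        }
        List<E> out = new ArrayList<>();
        if (n == 0) {
            it.forEachRemaining(out::add);
            return out;
        }
        List<E> buffer = new ArrayList<>(2 * n);
        E nthSmallest = null; // unknown until the first full buffer has been ordered
        while (it.hasNext()) {
            E e = it.next();
            if (nthSmallest != null && cmp.compare(e, nthSmallest) >= 0) {
                out.add(e); // cannot be among the n smallest: deliver at once
                continue;
            }
            buffer.add(e);
            if (buffer.size() == 2 * n) {
                buffer.sort(cmp);                     // quickselect around index n - 1 would suffice
                nthSmallest = buffer.get(n - 1);
                out.addAll(buffer.subList(n, 2 * n)); // deliver the right half
                buffer.subList(n, 2 * n).clear();     // keep the n smallest
            }
        }
        if (buffer.size() > n) { // source exhausted: order the last partial buffer once more
            buffer.sort(cmp);
            out.addAll(buffer.subList(n, buffer.size()));
        }
        return out;
    }
}
```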

private static <E> Iterator<E> precheck(Iterator<E> iterator, Comparator<? super E> comparator, int n) {
    Objects.requireNonNull(iterator, "iterator");
    Objects.requireNonNull(comparator, "comparator");

    if (n < 0) {
        throw new IllegalArgumentException("n must not be negative but was: " + n);
    }

    if (n == 0 || !iterator.hasNext()) { // nothing to do

        return iterator;

    } else if (n >= Integer.MAX_VALUE / 2) {

        // n * 2 would overflow the int buffer size, so fall back to sorting all elements
        Iterable<E> iterable = () -> iterator;
        return StreamSupport.stream(iterable.spliterator(), false)
                .sorted(comparator)
                .skip(n)
                .iterator();

    } else {

        return null;
    }
}
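The cutoff in precheck() comes down to int arithmetic: the buffer has n * 2 slots, and Integer.MAX_VALUE / 2 is 1073741823. For n = 2^30 and above, n * 2 no longer fits in an int and wraps around to a negative value, so new Object[n * 2] would throw NegativeArraySizeException instead of allocating. The tiny sketch below (hypothetical helper names) shows the wrap-around and a checked alternative based on Math.multiplyExact.

```java
public class BufferSizeOverflow {
    // Buffer size as computed in skipSorted(): plain int arithmetic, which wraps on overflow.
    public static int bufferSize(int n) {
        return n * 2;
    }

    // Checked alternative: Math.multiplyExact throws ArithmeticException instead of wrapping.
    public static boolean bufferSizeFits(int n) {
        try {
            Math.multiplyExact(n, 2);
            return true;
        } catch (ArithmeticException ex) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(Integer.MAX_VALUE / 2);   // prints 1073741823
        System.out.println(bufferSize(1 << 30));     // prints -2147483648
        System.out.println(bufferSizeFits(1 << 30)); // prints false
    }
}
```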

private static class SkipSortedIterator<E> implements Iterator<E> {
    private E toReturn; // will contain an element that can be returned without sorting the buffer
    private int beyondBuffer; // necessary for the last run (when the iterator is done) to mark the end of the elements in the buffer
    private int bufferIndex; // index when filling the buffer with the next elements or returning elements from the buffer
    private boolean returningFromTheBuffer = false;
    private E nthSmallest;
    private final int n;
    private final E[] buffer;
    private final Comparator<? super E> comparator;
    private final Random random = new Random(1); // fixed seed for the pivot choice: reproducible runs
    private final Iterator<E> iterator;

    private SkipSortedIterator(Iterator<E> iterator, int n, E[] buffer, Comparator<? super E> comparator, E nthSmallest) {
        this.iterator = iterator;
        this.n = this.bufferIndex = n;
        this.buffer = buffer;
        this.comparator = comparator;
        this.nthSmallest = nthSmallest;
        this.beyondBuffer = buffer.length;
    }

    @Override
    public boolean hasNext() {
        if (toReturn != null) {
            return true;
        }
        if (returningFromTheBuffer && (bufferIndex < beyondBuffer)) { // there are elements in the buffer that can be returned
            return true;
        }
        while (iterator.hasNext()) {
            E e = iterator.next();
            if (comparator.compare(e, nthSmallest) >= 0) {
                toReturn = e; 
                return true;
            }
            buffer[bufferIndex++] = e;
            if (bufferIndex == buffer.length) {
                // now partially sort the buffer: afterwards, no element at index n or above is among the n smallest
                quickselect_n();

                bufferIndex = n;
                returningFromTheBuffer = true;
                nthSmallest = buffer[n - 1];

                return true;
            }
        }
        // the source is exhausted: if the buffer holds more than n elements, partition it once more and return the surplus
        if (bufferIndex > n) {
            beyondBuffer = bufferIndex;
            quickselect_n();

            bufferIndex = n;
            returningFromTheBuffer = true;

            return true;
        }
        return false;
    }

    /**
     * We are using quickselect to partially sort the buffer. Quickselect works like this: choose an arbitrary array element
     * and iterate over all elements of the array and 
     * - leave the inspected element where it is when it is greater than or equal to the chosen value
     * - move the inspected element to a new position at the left of the array when it is smaller than the chosen value (the 
     *   "store index" used will then be incremented)
     * When all elements have been inspected then the chosen value will be put at the "store index" and the store index will be 
     * returned.
     * The buffer is now partially sorted: below the "store index" all values are lower than the chosen value, beyond the "store index"
     * all values are greater than or equal to the chosen value.
     * Now if the "store index" is n - 1 (the zero-based position of the nth smallest element) then we are done. Otherwise we
     * have to "sort" again, but only to the left of the "store index" or only to the right of it.
     * Pseudocode for quickselect from Wikipedia (the pseudocode for partition() follows below):
     * function select(list, left, right, n)
     *    loop
     *       if left = right
     *          return list[left]
     *       pivotIndex := ...     // select pivotIndex between left and right
     *       pivotIndex := partition(list, left, right, pivotIndex)
     *       if n = pivotIndex
     *          return list[n]
     *       else if n < pivotIndex
     *          right := pivotIndex - 1
     *       else
     *          left := pivotIndex + 1
     */
    void quickselect_n() {
        int left = 0;
        int right = beyondBuffer - 1;

        while (left < right) {
            int pivotIndex = left + random.nextInt(right - left + 1);

            int pivotNewIndex = partition(left, right, pivotIndex);
            if (pivotNewIndex == n - 1) {
                break;
            } else if (n - 1 < pivotNewIndex) {
                right = pivotNewIndex - 1;
            } else {
                left = pivotNewIndex + 1;
            } 
        }
    }

    /**
     * Pseudocode for function partition() from Wikipedia:
     * function partition(list, left, right, pivotIndex)
     *    pivotValue := list[pivotIndex]
     *    swap list[pivotIndex] and list[right]  // Move pivot to end
     *    storeIndex := left
     *    for i from left to right-1
     *       if list[i] < pivotValue
     *          swap list[storeIndex] and list[i]
     *          increment storeIndex
     *    swap list[right] and list[storeIndex]  // Move pivot to its final place
     *    return storeIndex
     */
    private int partition(int left, int right, int pivotIndex) {
        // 1. the chosen value is the one at the pivotIndex
        E pivotValue = buffer[pivotIndex];

        // 2. the buffer element at the pivotIndex is set to the value of the rightmost array element (this way the rightmost
        //    element does not need to be inspected)
        buffer[pivotIndex] = buffer[right];

        // 3. iterate from left to right; as long as the values are smaller than the chosen value they are already in place,
        //    so only the store index advances
        int storeIndex = left;
        int i = left;
        for (; i < right; i++) {
            if (comparator.compare(buffer[i], pivotValue) < 0) {
                storeIndex++;
            } else {
                break;
            }
        }
        // 4. continue iterating but now move the value to the left if it is smaller than the chosen value
        for (i++; i < right; i++) {
            if (comparator.compare(buffer[i], pivotValue) < 0) {
                E tmp = buffer[storeIndex];
                buffer[storeIndex] = buffer[i];
                buffer[i] = tmp;
                storeIndex++;
            }
        }

        // 5. move the chosen value to the store index
        buffer[right] = buffer[storeIndex];
        buffer[storeIndex] = pivotValue;
        return storeIndex;
    }   

    @Override
    public E next() {
        if (hasNext()) {
            if (returningFromTheBuffer && bufferIndex < beyondBuffer) {
                E tmp = buffer[bufferIndex];
                if (++bufferIndex == beyondBuffer) {
                    bufferIndex = n;
                    returningFromTheBuffer = false;
                }
                return tmp;
            }
            E tmp = toReturn;
            toReturn = null;
            return tmp;
        } else {
            throw new NoSuchElementException();
        }
    }
}
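For experimenting with the selection step in isolation, here is a compact quickselect on an int[] that follows the Wikipedia pseudocode quoted in the comments above (random pivot, Lomuto partition). It is a simplified stand-in with a hypothetical class name: it uses plain swaps rather than the partition variant above, which skips leading elements that are already smaller than the pivot. In SkipSortedIterator terms, k corresponds to n - 1 and the array to the 2n buffer.

```java
import java.util.Random;

public class QuickselectSketch {
    // Rearranges a so that a[k] holds the value it would have after sorting,
    // with no larger value to its left and no smaller value to its right.
    public static void select(int[] a, int k, Random random) {
        if (k < 0 || k >= a.length) {
            throw new IllegalArgumentException("k out of range: " + k);
        }
        int left = 0;
        int right = a.length - 1;
        while (left < right) {
            int pivotIndex = left + random.nextInt(right - left + 1);
            int p = partition(a, left, right, pivotIndex);
            if (p == k) {
                return;
            } else if (k < p) {
                right = p - 1;
            } else {
                left = p + 1;
            }
        }
    }

    // Lomuto partition of a[left..right]: values < pivot end up left of the returned index,
    // all other values right of it.
    public static int partition(int[] a, int left, int right, int pivotIndex) {
        int pivot = a[pivotIndex];
        swap(a, pivotIndex, right); // move pivot to the end
        int store = left;
        for (int i = left; i < right; i++) {
            if (a[i] < pivot) {
                swap(a, store++, i);
            }
        }
        swap(a, right, store); // move pivot to its final place
        return store;
    }

    private static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }
}
```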

public static void test(int k, int factor) {
    int bufferCap = factor * k;
    int[] values = new int[bufferCap];
    Map<Integer, Integer> map = new TreeMap<>();
    Random random = new Random();
    int j = 0;
    while (j < bufferCap) {
        int value = random.nextInt();
        map.merge(value, 1, Integer::sum); // count how often each value occurs
        values[j++] = value;
    }
    int value = 0; // will hold the k-th smallest value
    int count = 0;
    for (Entry<Integer, Integer> entry : map.entrySet()) {
        count += entry.getValue();
        if (count >= k) {
            value = entry.getKey();
            break;
        }
    }

    Iterator<Integer> iterator = new Iterator<Integer>() {

        int index = 0;
        @Override
        public boolean hasNext() {
            return index < bufferCap;
        }

        @Override
        public Integer next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            return values[index++];
        }
    };

    Iterator<Integer> x = skipSorted(iterator, Comparator.naturalOrder(), k);

    count = 0;
    while (x.hasNext()) {
        Integer num = x.next();
        count++;
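        if (num < value) {
            throw new AssertionError("returned " + num + ", which is below the k-th smallest value " + value);
        }
    }
    if (count != bufferCap - k) {
        throw new AssertionError("expected " + (bufferCap - k) + " elements but got " + count);
    }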

}
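The test above checks the element count and that nothing below the k-th smallest value is returned. A stricter check, sketched here with a hypothetical helper, compares the output as a multiset with the reference "sort, then skip n"; the output is sorted first because skipSorted does not promise any particular order. The implementation under test is passed in as a function, so the same check can be pointed at skipSorted or at any of the alternatives.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

public class SkipSortedCheck {
    // Hypothetical test helper: runs impl on the input and compares its output, as a multiset,
    // with the reference "sort, then skip n". The output is sorted before comparing because
    // skipSorted does not promise any particular order.
    public static boolean matchesReference(List<Integer> input, int n,
            BiFunction<Iterator<Integer>, Integer, Iterator<Integer>> impl) {
        List<Integer> actual = new ArrayList<>();
        impl.apply(input.iterator(), n).forEachRemaining(actual::add);
        actual.sort(Comparator.naturalOrder());
        List<Integer> expected = input.stream().sorted().skip(n).collect(Collectors.toList());
        return actual.equals(expected);
    }
}
```

A call such as matchesReference(list, k, (it, m) -> skipSorted(it, Comparator.naturalOrder(), m)) would plug in the method from this issue; note that its parameter order is (iterator, comparator, n).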

public static void main(String[] args) {
    Random random = new Random();
    for (int i = 0; i < 1000; i++) {
        int k = 50 + random.nextInt(1000);
        int factor = 2 + random.nextInt(10);
        test(k, factor);
    }
}
