Description
Requested here by @AgnetaWalterscheidt:
I would like to implement a new method in StreamEx (or AbstractStreamEx):
public StreamEx<E> skipSorted(int n, Comparator<? super E> comparator);
The method skips the first n elements of the stream as if the stream were sorted using the comparator, but without actually sorting the whole stream.
The method is similar to MoreCollectors.greatest(Comparator<? super T>, int) but
- it uses the quickselect algorithm to do the partial sort and
- it returns a stream instead of a Collector (which makes sense because skipping n elements can still return a stream with an infinite number of elements).
For symmetry I would like to implement a second method:
public StreamEx<E> limitSorted(int n, Comparator<? super E> comparator);
which is the counterpart to the first method: it limits the stream to the first n elements (after "sorting" it as in the first method). This method obviously cannot return a stream with an infinite number of elements (the stream will have n or fewer elements), but I think it would be confusing for users of your library to have to look for the second method somewhere else.
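To make the intended results concrete, here is a minimal reference sketch of what the two methods should produce. It sorts everything, which is exactly what the proposal wants to avoid, so it only serves as a specification to compare against. The class and method names are hypothetical, and treating the output as an unordered collection of elements is an assumption, since the proposal does not say in which order the remaining elements are delivered.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class SortedSliceReference {
    // Reference behaviour only: fully sorts the input, then drops the n smallest.
    static <E> List<E> skipSortedReference(List<E> input, Comparator<? super E> comparator, int n) {
        return input.stream().sorted(comparator).skip(n).collect(Collectors.toList());
    }

    // Reference behaviour only: fully sorts the input, then keeps the n smallest.
    static <E> List<E> limitSortedReference(List<E> input, Comparator<? super E> comparator, int n) {
        return input.stream().sorted(comparator).limit(n).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> data = List.of(5, 1, 4, 2, 3);
        System.out.println(skipSortedReference(data, Comparator.naturalOrder(), 2));  // [3, 4, 5]
        System.out.println(limitSortedReference(data, Comparator.naturalOrder(), 2)); // [1, 2]
    }
}
```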
I have already implemented a method skipSorted() that takes and returns an Iterator. The code is below, including a small test method (currently called from a main method).
I have also implemented a skipSorted() variant that uses a PriorityQueue for sorting and one that uses a TreeSet (not included, but I can send you the code if you are interested). The quickselect implementation is on average about three times faster than the others.
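For comparison, a PriorityQueue-based variant might look roughly like the sketch below. This is not the code referred to above, only an assumption of how such a variant could work: a max-heap keeps the n smallest elements seen so far, and every element that falls out of the heap (or never enters it) is passed on. The class name is hypothetical, and null elements are not supported because PriorityQueue rejects them.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;

class HeapSkipSorted {
    static <E> Iterator<E> skipSorted(Iterator<E> source, Comparator<? super E> comparator, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative but was: " + n);
        }
        if (n == 0) {
            return source;
        }
        // Max-heap (reversed comparator) holding the n smallest elements seen so far.
        PriorityQueue<E> smallest = new PriorityQueue<E>((a, b) -> comparator.compare(b, a));
        return new Iterator<E>() {
            private E pending;
            private boolean hasPending;

            @Override
            public boolean hasNext() {
                while (!hasPending && source.hasNext()) {
                    E e = source.next();
                    if (smallest.size() < n) {
                        smallest.add(e);           // still collecting the first n candidates
                    } else if (comparator.compare(e, smallest.peek()) < 0) {
                        pending = smallest.poll(); // evicted: no longer among the n smallest
                        smallest.add(e);
                        hasPending = true;
                    } else {
                        pending = e;               // not smaller than the current nth smallest
                        hasPending = true;
                    }
                }
                return hasPending;
            }

            @Override
            public E next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                E result = pending;
                pending = null;
                hasPending = false;
                return result;
            }
        };
    }

    public static void main(String[] args) {
        List<Integer> out = new ArrayList<>();
        skipSorted(List.of(5, 1, 4, 2, 3).iterator(), Comparator.<Integer>naturalOrder(), 2)
                .forEachRemaining(out::add);
        System.out.println(out); // [5, 4, 3]
    }
}
```

Each element costs O(log n) heap work here, whereas the buffered quickselect approach amortizes one partial sort over n new elements, which may explain part of the measured difference.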
The StreamEx method could be implemented as a small wrapper around these methods.
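As a rough illustration of such a wrapper, the sketch below uses only java.util.stream; the helper name viaIterator is hypothetical, and how this would be wired into StreamEx itself is left open. It takes the stream's iterator, applies an iterator-to-iterator transformation, and wraps the result in a sequential stream again.

```java
import java.util.Iterator;
import java.util.List;
import java.util.Spliterators;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class IteratorStreamWrapper {
    // Applies an Iterator -> Iterator transformation to a stream and returns a new sequential stream.
    static <E> Stream<E> viaIterator(Stream<E> stream, UnaryOperator<Iterator<E>> transform) {
        Iterator<E> transformed = transform.apply(stream.iterator());
        return StreamSupport
                .stream(Spliterators.spliteratorUnknownSize(transformed, 0), false)
                .onClose(stream::close); // closing the result also closes the source stream
    }

    public static void main(String[] args) {
        // Stand-in transformation that simply drops the first element.
        List<Integer> out = viaIterator(Stream.of(5, 1, 4), it -> { it.next(); return it; })
                .collect(Collectors.toList());
        System.out.println(out); // [1, 4]
    }
}
```

With the iterator-based method from this issue, the call would presumably be something like `viaIterator(stream, it -> skipSorted(it, comparator, n))`.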
Please tell me if you would like me to implement this!
Cheers
Agneta
Code:
/**
* Implementation details: a buffer of size 2*n is filled. Then this buffer is partially sorted with the quickselect algorithm
* (see below). Then the elements from the right half of the buffer will be returned. When all these elements have been returned
* then the buffer is filled again.
* Additionally, while the buffer is being filled, elements are returned immediately if they are greater than or equal to the
* nth smallest element seen so far.
*
* @param iterator the iterator that returns the elements of which n elements are to be skipped
* @param comparator the comparator used for sorting
* @param n the number of elements to skip
* @return an iterator that will return the rest of the elements
*/
public static <E> Iterator<E> skipSorted(Iterator<E> iterator, Comparator<? super E> comparator, int n) {
Iterator<E> result = precheck(iterator, comparator, n);
if (result == null) {
@SuppressWarnings("unchecked") // the cast is formally unchecked, but the error cannot surface: the array only ever holds E's
E[] buffer = (E[]) new Object[n * 2];
E nthSmallest = iterator.next(); // when we have seen n elements this is the maximum (the nth smallest element seen so far);
// this means that when we see more elements we can already deliver the ones that are greater (because obviously they cannot
// be among the n smallest)
buffer[0] = nthSmallest;
int bufferIndex = 1;
while (bufferIndex < n && iterator.hasNext()) {
E e = iterator.next();
buffer[bufferIndex++] = e;
if (comparator.compare(e, nthSmallest) > 0) {
nthSmallest = e;
}
}
if (!iterator.hasNext()) { // at most n elements in total: everything is skipped
result = iterator;
} else {
result = new SkipSortedIterator<>(iterator, n, buffer, comparator, nthSmallest);
}
}
return result;
}
private static <E> Iterator<E> precheck(Iterator<E> iterator, Comparator<? super E> comparator, int n) {
Objects.requireNonNull(iterator);
if (n < 0) {
throw new IllegalArgumentException("n must not be negative but was: " + n);
}
if (n == 0 || !iterator.hasNext()) { // nothing to do
return iterator;
} else if (n >= Integer.MAX_VALUE / 2) {
// if n is very large then a buffer of size 2 * n cannot be allocated, so sort all elements instead
Iterable<E> iterable = () -> iterator;
return StreamSupport.stream(iterable.spliterator(), false)
.sorted(comparator)
.skip(n)
.iterator();
} else {
return null;
}
}
private static class SkipSortedIterator<E> implements Iterator<E> {
private E toReturn; // will contain an element that can be returned without sorting the buffer
private int beyondBuffer; // necessary for the last run (when the iterator is done) to mark the end of the elements in the buffer
private int bufferIndex; // index when filling the buffer with the next elements or returning elements from the buffer
private boolean returningFromTheBuffer = false;
private E nthSmallest;
private final int n;
private final E[] buffer;
private final Comparator<? super E> comparator;
private final Random random = new Random(1);
private Iterator<E> iterator;
private SkipSortedIterator(Iterator<E> iterator, int n, E[] buffer, Comparator<? super E> comparator, E nthSmallest) {
this.iterator = iterator;
this.n = this.bufferIndex = n;
this.buffer = buffer;
this.comparator = comparator;
this.nthSmallest = nthSmallest;
this.beyondBuffer = buffer.length;
}
@Override
public boolean hasNext() {
if (toReturn != null) {
return true;
}
if (returningFromTheBuffer && (bufferIndex < beyondBuffer)) { // there are elements in the buffer that can be returned
return true;
}
while (iterator.hasNext()) {
E e = iterator.next();
if (comparator.compare(e, nthSmallest) >= 0) {
toReturn = e;
return true;
}
buffer[bufferIndex++] = e;
if (bufferIndex == buffer.length) {
// now we partially sort the buffer: afterwards the elements at index n and above are not among the n smallest elements
quickselect_n();
bufferIndex = n;
returningFromTheBuffer = true;
nthSmallest = buffer[n - 1];
return true;
}
}
// the iterator has reached its end, maybe sort again
if (bufferIndex > n) {
beyondBuffer = bufferIndex;
quickselect_n();
bufferIndex = n;
returningFromTheBuffer = true;
return true;
}
return false;
}
/**
* We are using quickselect to partially sort the buffer. Quickselect works like this: choose an arbitrary array element
* and iterate over all elements of the array and
* - leave the inspected element where it is when it is greater than or equal to the chosen value
* - move the inspected element to a new position at the left of the array when it is smaller than the chosen value (the
* "store index" used will then be incremented)
* When all elements have been inspected then the chosen value will be put at the "store index" and the store index will be
* returned.
* The buffer is now partially sorted: below the "store index" all values are lower than the chosen value, beyond the "store index"
* all values are greater than or equal to the chosen value.
* Now if the "store index" is n - 1 (the 0-based position of the nth smallest element) then we are done. Otherwise we have to
* "sort" again, but only to the left or only to the right of the "store index".
* Pseudocode for quickselect from wikipedia (pseudocode for function partition() see below):
* function select(list, left, right, n)
* loop
* if left = right
* return list[left]
* pivotIndex := ... // select pivotIndex between left and right
* pivotIndex := partition(list, left, right, pivotIndex)
* if n = pivotIndex
* return list[n]
* else if n < pivotIndex
* right := pivotIndex - 1
* else
* left := pivotIndex + 1
*/
void quickselect_n() {
int left = 0;
int right = beyondBuffer - 1;
while (left < right) {
int pivotIndex = left + random.nextInt(right - left + 1);
int pivotNewIndex = partition(left, right, pivotIndex);
if (pivotNewIndex == n - 1) {
break;
} else if (n - 1 < pivotNewIndex) {
right = pivotNewIndex - 1;
} else {
left = pivotNewIndex + 1;
}
}
}
/**
* Pseudocode for function partition() from wikipedia:
* function partition(list, left, right, pivotIndex)
* pivotValue := list[pivotIndex]
* swap list[pivotIndex] and list[right] // Move pivot to end
* storeIndex := left
* for i from left to right-1
* if list[i] < pivotValue
* swap list[storeIndex] and list[i]
* increment storeIndex
* swap list[right] and list[storeIndex] // Move pivot to its final place
* return storeIndex
*/
private int partition(int left, int right, int pivotIndex) {
// 1. the chosen value is the one at the pivotIndex
E pivotValue = buffer[pivotIndex];
// 2. the buffer element at the pivotIndex is set to the value of the rightmost array element (this way the rightmost element
// does not need to be inspected)
buffer[pivotIndex] = buffer[right];
// 3. iterate from left to right and keep the value where it is if it is smaller than the chosen value
int storeIndex = left;
int i = left;
for (; i < right; i++) {
if (comparator.compare(buffer[i], pivotValue) < 0) {
storeIndex++;
} else {
break;
}
}
// 4. continue iterating but now move the value to the left if it is smaller than the chosen value
for (i++; i < right; i++) {
if (comparator.compare(buffer[i], pivotValue) < 0) {
E tmp = buffer[storeIndex];
buffer[storeIndex] = buffer[i];
buffer[i] = tmp;
storeIndex++;
}
}
// 5. move the chosen value to the store index
buffer[right] = buffer[storeIndex];
buffer[storeIndex] = pivotValue;
return storeIndex;
}
@Override
public E next() {
if (hasNext()) {
if (returningFromTheBuffer && bufferIndex < beyondBuffer) {
E tmp = buffer[bufferIndex];
if (++bufferIndex == beyondBuffer) {
bufferIndex = n;
returningFromTheBuffer = false;
}
return tmp;
}
E tmp = toReturn;
toReturn = null;
return tmp;
} else {
throw new NoSuchElementException();
}
}
}
public static void test(int k, int factor) {
int bufferCap = factor * k;
int[] values = new int[bufferCap];
Map<Integer, Integer> map = new TreeMap<>();
Random random = new Random();
int j = 0;
while (j < bufferCap) {
int value = random.nextInt();
map.merge(value, 1, Integer::sum); // count how often each value occurs
values[j++] = value;
}
int value = 0;
int count = 0;
for (Entry<Integer, Integer> entry : map.entrySet()) {
count += entry.getValue();
if (count >= k) {
value = entry.getKey();
break;
}
}
Iterator<Integer> iterator = new Iterator<Integer>() {
int index = 0;
@Override
public boolean hasNext() {
return index < bufferCap;
}
@Override
public Integer next() {
return values[index++];
}
};
Iterator<Integer> x = skipSorted(iterator, Comparator.naturalOrder(), k);
count = 0;
while (x.hasNext()) {
Integer num = x.next();
count++;
if (num < value) {
throw new RuntimeException("returned element " + num + " is smaller than the kth smallest value " + value);
}
}
if (count != bufferCap - k) {
throw new RuntimeException("expected " + (bufferCap - k) + " elements but got " + count);
}
}
public static void main(String[] args) {
Random random = new Random();
for (int i = 0; i < 1000; i++) {
int k = 50 + random.nextInt(1000);
int size = 2 + random.nextInt(10);
test(k, size);
}
}
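The partial sort described in the Javadoc comments above can also be tried in isolation. The following is a minimal sketch on a plain int array, assuming the partition scheme from the quoted pseudocode; the class and method names are hypothetical, and it is a simplified illustration rather than the buffer-based code from this issue.

```java
import java.util.Random;

class QuickselectSketch {
    private static final Random RANDOM = new Random(1);

    // Rearranges a so that a[k] holds the value it would have after a full sort;
    // everything left of k is <= a[k] and everything right of k is >= a[k].
    // Assumes 0 <= k < a.length.
    static void select(int[] a, int k) {
        int left = 0;
        int right = a.length - 1;
        while (left < right) {
            int pivotIndex = left + RANDOM.nextInt(right - left + 1);
            int p = partition(a, left, right, pivotIndex);
            if (p == k) {
                return;
            } else if (k < p) {
                right = p - 1; // the kth element lies left of the pivot
            } else {
                left = p + 1;  // the kth element lies right of the pivot
            }
        }
    }

    // Moves all values smaller than the pivot to the left and returns the pivot's final index.
    static int partition(int[] a, int left, int right, int pivotIndex) {
        int pivot = a[pivotIndex];
        swap(a, pivotIndex, right); // move the pivot to the end
        int store = left;
        for (int i = left; i < right; i++) {
            if (a[i] < pivot) {
                swap(a, store++, i);
            }
        }
        swap(a, right, store);      // move the pivot to its final place
        return store;
    }

    private static void swap(int[] a, int i, int j) {
        int tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }

    public static void main(String[] args) {
        int[] data = {9, 3, 7, 1, 8, 2};
        select(data, 2);
        System.out.println(data[2]); // 3
    }
}
```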