e6fca4d47f661ff16fdc8c2bb7ae5b86c7f347b2
[SubU] /
1 import warnings
2
3 from collections import Counter, defaultdict, deque, abc
4 from collections.abc import Sequence
5 from functools import partial, reduce, wraps
6 from heapq import merge, heapify, heapreplace, heappop
7 from itertools import (
8     chain,
9     compress,
10     count,
11     cycle,
12     dropwhile,
13     groupby,
14     islice,
15     repeat,
16     starmap,
17     takewhile,
18     tee,
19     zip_longest,
20 )
21 from math import exp, factorial, floor, log
22 from queue import Empty, Queue
23 from random import random, randrange, uniform
24 from operator import itemgetter, mul, sub, gt, lt
25 from sys import hexversion, maxsize
26 from time import monotonic
27
28 from .recipes import (
29     consume,
30     flatten,
31     pairwise,
32     powerset,
33     take,
34     unique_everseen,
35 )
36
37 __all__ = [
38     'AbortThread',
39     'adjacent',
40     'always_iterable',
41     'always_reversible',
42     'bucket',
43     'callback_iter',
44     'chunked',
45     'circular_shifts',
46     'collapse',
47     'collate',
48     'consecutive_groups',
49     'consumer',
50     'countable',
51     'count_cycle',
52     'mark_ends',
53     'difference',
54     'distinct_combinations',
55     'distinct_permutations',
56     'distribute',
57     'divide',
58     'exactly_n',
59     'filter_except',
60     'first',
61     'groupby_transform',
62     'ilen',
63     'interleave_longest',
64     'interleave',
65     'intersperse',
66     'islice_extended',
67     'iterate',
68     'ichunked',
69     'is_sorted',
70     'last',
71     'locate',
72     'lstrip',
73     'make_decorator',
74     'map_except',
75     'map_reduce',
76     'nth_or_last',
77     'nth_permutation',
78     'nth_product',
79     'numeric_range',
80     'one',
81     'only',
82     'padded',
83     'partitions',
84     'set_partitions',
85     'peekable',
86     'repeat_last',
87     'replace',
88     'rlocate',
89     'rstrip',
90     'run_length',
91     'sample',
92     'seekable',
93     'SequenceView',
94     'side_effect',
95     'sliced',
96     'sort_together',
97     'split_at',
98     'split_after',
99     'split_before',
100     'split_when',
101     'split_into',
102     'spy',
103     'stagger',
104     'strip',
105     'substrings',
106     'substrings_indexes',
107     'time_limited',
108     'unique_to_each',
109     'unzip',
110     'windowed',
111     'with_iter',
112     'UnequalIterablesError',
113     'zip_equal',
114     'zip_offset',
115     'windowed_complete',
116     'all_unique',
117     'value_chain',
118     'product_index',
119     'combination_index',
120     'permutation_index',
121 ]
122
123 _marker = object()
124
125
126 def chunked(iterable, n, strict=False):
127     """Break *iterable* into lists of length *n*:
128
129         >>> list(chunked([1, 2, 3, 4, 5, 6], 3))
130         [[1, 2, 3], [4, 5, 6]]
131
132     By the default, the last yielded list will have fewer than *n* elements
133     if the length of *iterable* is not divisible by *n*:
134
135         >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3))
136         [[1, 2, 3], [4, 5, 6], [7, 8]]
137
138     To use a fill-in value instead, see the :func:`grouper` recipe.
139
140     If the length of *iterable* is not divisible by *n* and *strict* is
141     ``True``, then ``ValueError`` will be raised before the last
142     list is yielded.
143
144     """
145     iterator = iter(partial(take, n, iter(iterable)), [])
146     if strict:
147
148         def ret():
149             for chunk in iterator:
150                 if len(chunk) != n:
151                     raise ValueError('iterable is not divisible by n.')
152                 yield chunk
153
154         return iter(ret())
155     else:
156         return iterator
157
158
159 def first(iterable, default=_marker):
160     """Return the first item of *iterable*, or *default* if *iterable* is
161     empty.
162
163         >>> first([0, 1, 2, 3])
164         0
165         >>> first([], 'some default')
166         'some default'
167
168     If *default* is not provided and there are no items in the iterable,
169     raise ``ValueError``.
170
171     :func:`first` is useful when you have a generator of expensive-to-retrieve
172     values and want any arbitrary one. It is marginally shorter than
173     ``next(iter(iterable), default)``.
174
175     """
176     try:
177         return next(iter(iterable))
178     except StopIteration as e:
179         if default is _marker:
180             raise ValueError(
181                 'first() was called on an empty iterable, and no '
182                 'default value was provided.'
183             ) from e
184         return default
185
186
187 def last(iterable, default=_marker):
188     """Return the last item of *iterable*, or *default* if *iterable* is
189     empty.
190
191         >>> last([0, 1, 2, 3])
192         3
193         >>> last([], 'some default')
194         'some default'
195
196     If *default* is not provided and there are no items in the iterable,
197     raise ``ValueError``.
198     """
199     try:
200         if isinstance(iterable, Sequence):
201             return iterable[-1]
202         # Work around https://bugs.python.org/issue38525
203         elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0):
204             return next(reversed(iterable))
205         else:
206             return deque(iterable, maxlen=1)[-1]
207     except (IndexError, TypeError, StopIteration):
208         if default is _marker:
209             raise ValueError(
210                 'last() was called on an empty iterable, and no default was '
211                 'provided.'
212             )
213         return default
214
215
216 def nth_or_last(iterable, n, default=_marker):
217     """Return the nth or the last item of *iterable*,
218     or *default* if *iterable* is empty.
219
220         >>> nth_or_last([0, 1, 2, 3], 2)
221         2
222         >>> nth_or_last([0, 1], 2)
223         1
224         >>> nth_or_last([], 0, 'some default')
225         'some default'
226
227     If *default* is not provided and there are no items in the iterable,
228     raise ``ValueError``.
229     """
230     return last(islice(iterable, n + 1), default=default)
231
232
233 class peekable:
234     """Wrap an iterator to allow lookahead and prepending elements.
235
236     Call :meth:`peek` on the result to get the value that will be returned
237     by :func:`next`. This won't advance the iterator:
238
239         >>> p = peekable(['a', 'b'])
240         >>> p.peek()
241         'a'
242         >>> next(p)
243         'a'
244
245     Pass :meth:`peek` a default value to return that instead of raising
246     ``StopIteration`` when the iterator is exhausted.
247
248         >>> p = peekable([])
249         >>> p.peek('hi')
250         'hi'
251
252     peekables also offer a :meth:`prepend` method, which "inserts" items
253     at the head of the iterable:
254
255         >>> p = peekable([1, 2, 3])
256         >>> p.prepend(10, 11, 12)
257         >>> next(p)
258         10
259         >>> p.peek()
260         11
261         >>> list(p)
262         [11, 12, 1, 2, 3]
263
264     peekables can be indexed. Index 0 is the item that will be returned by
265     :func:`next`, index 1 is the item after that, and so on:
266     The values up to the given index will be cached.
267
268         >>> p = peekable(['a', 'b', 'c', 'd'])
269         >>> p[0]
270         'a'
271         >>> p[1]
272         'b'
273         >>> next(p)
274         'a'
275
276     Negative indexes are supported, but be aware that they will cache the
277     remaining items in the source iterator, which may require significant
278     storage.
279
280     To check whether a peekable is exhausted, check its truth value:
281
282         >>> p = peekable(['a', 'b'])
283         >>> if p:  # peekable has items
284         ...     list(p)
285         ['a', 'b']
286         >>> if not p:  # peekable is exhausted
287         ...     list(p)
288         []
289
290     """
291
292     def __init__(self, iterable):
293         self._it = iter(iterable)
294         self._cache = deque()
295
296     def __iter__(self):
297         return self
298
299     def __bool__(self):
300         try:
301             self.peek()
302         except StopIteration:
303             return False
304         return True
305
306     def peek(self, default=_marker):
307         """Return the item that will be next returned from ``next()``.
308
309         Return ``default`` if there are no items left. If ``default`` is not
310         provided, raise ``StopIteration``.
311
312         """
313         if not self._cache:
314             try:
315                 self._cache.append(next(self._it))
316             except StopIteration:
317                 if default is _marker:
318                     raise
319                 return default
320         return self._cache[0]
321
322     def prepend(self, *items):
323         """Stack up items to be the next ones returned from ``next()`` or
324         ``self.peek()``. The items will be returned in
325         first in, first out order::
326
327             >>> p = peekable([1, 2, 3])
328             >>> p.prepend(10, 11, 12)
329             >>> next(p)
330             10
331             >>> list(p)
332             [11, 12, 1, 2, 3]
333
334         It is possible, by prepending items, to "resurrect" a peekable that
335         previously raised ``StopIteration``.
336
337             >>> p = peekable([])
338             >>> next(p)
339             Traceback (most recent call last):
340               ...
341             StopIteration
342             >>> p.prepend(1)
343             >>> next(p)
344             1
345             >>> next(p)
346             Traceback (most recent call last):
347               ...
348             StopIteration
349
350         """
351         self._cache.extendleft(reversed(items))
352
353     def __next__(self):
354         if self._cache:
355             return self._cache.popleft()
356
357         return next(self._it)
358
359     def _get_slice(self, index):
360         # Normalize the slice's arguments
361         step = 1 if (index.step is None) else index.step
362         if step > 0:
363             start = 0 if (index.start is None) else index.start
364             stop = maxsize if (index.stop is None) else index.stop
365         elif step < 0:
366             start = -1 if (index.start is None) else index.start
367             stop = (-maxsize - 1) if (index.stop is None) else index.stop
368         else:
369             raise ValueError('slice step cannot be zero')
370
371         # If either the start or stop index is negative, we'll need to cache
372         # the rest of the iterable in order to slice from the right side.
373         if (start < 0) or (stop < 0):
374             self._cache.extend(self._it)
375         # Otherwise we'll need to find the rightmost index and cache to that
376         # point.
377         else:
378             n = min(max(start, stop) + 1, maxsize)
379             cache_len = len(self._cache)
380             if n >= cache_len:
381                 self._cache.extend(islice(self._it, n - cache_len))
382
383         return list(self._cache)[index]
384
385     def __getitem__(self, index):
386         if isinstance(index, slice):
387             return self._get_slice(index)
388
389         cache_len = len(self._cache)
390         if index < 0:
391             self._cache.extend(self._it)
392         elif index >= cache_len:
393             self._cache.extend(islice(self._it, index + 1 - cache_len))
394
395         return self._cache[index]
396
397
398 def collate(*iterables, **kwargs):
399     """Return a sorted merge of the items from each of several already-sorted
400     *iterables*.
401
402         >>> list(collate('ACDZ', 'AZ', 'JKL'))
403         ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z']
404
405     Works lazily, keeping only the next value from each iterable in memory. Use
406     :func:`collate` to, for example, perform a n-way mergesort of items that
407     don't fit in memory.
408
409     If a *key* function is specified, the iterables will be sorted according
410     to its result:
411
412         >>> key = lambda s: int(s)  # Sort by numeric value, not by string
413         >>> list(collate(['1', '10'], ['2', '11'], key=key))
414         ['1', '2', '10', '11']
415
416
417     If the *iterables* are sorted in descending order, set *reverse* to
418     ``True``:
419
420         >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True))
421         [5, 4, 3, 2, 1, 0]
422
423     If the elements of the passed-in iterables are out of order, you might get
424     unexpected results.
425
426     On Python 3.5+, this function is an alias for :func:`heapq.merge`.
427
428     """
429     warnings.warn(
430         "collate is no longer part of more_itertools, use heapq.merge",
431         DeprecationWarning,
432     )
433     return merge(*iterables, **kwargs)
434
435
436 def consumer(func):
437     """Decorator that automatically advances a PEP-342-style "reverse iterator"
438     to its first yield point so you don't have to call ``next()`` on it
439     manually.
440
441         >>> @consumer
442         ... def tally():
443         ...     i = 0
444         ...     while True:
445         ...         print('Thing number %s is %s.' % (i, (yield)))
446         ...         i += 1
447         ...
448         >>> t = tally()
449         >>> t.send('red')
450         Thing number 0 is red.
451         >>> t.send('fish')
452         Thing number 1 is fish.
453
454     Without the decorator, you would have to call ``next(t)`` before
455     ``t.send()`` could be used.
456
457     """
458
459     @wraps(func)
460     def wrapper(*args, **kwargs):
461         gen = func(*args, **kwargs)
462         next(gen)
463         return gen
464
465     return wrapper
466
467
468 def ilen(iterable):
469     """Return the number of items in *iterable*.
470
471         >>> ilen(x for x in range(1000000) if x % 3 == 0)
472         333334
473
474     This consumes the iterable, so handle with care.
475
476     """
477     # This approach was selected because benchmarks showed it's likely the
478     # fastest of the known implementations at the time of writing.
479     # See GitHub tracker: #236, #230.
480     counter = count()
481     deque(zip(iterable, counter), maxlen=0)
482     return next(counter)
483
484
485 def iterate(func, start):
486     """Return ``start``, ``func(start)``, ``func(func(start))``, ...
487
488     >>> from itertools import islice
489     >>> list(islice(iterate(lambda x: 2*x, 1), 10))
490     [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
491
492     """
493     while True:
494         yield start
495         start = func(start)
496
497
498 def with_iter(context_manager):
499     """Wrap an iterable in a ``with`` statement, so it closes once exhausted.
500
501     For example, this will close the file when the iterator is exhausted::
502
503         upper_lines = (line.upper() for line in with_iter(open('foo')))
504
505     Any context manager which returns an iterable is a candidate for
506     ``with_iter``.
507
508     """
509     with context_manager as iterable:
510         yield from iterable
511
512
513 def one(iterable, too_short=None, too_long=None):
514     """Return the first item from *iterable*, which is expected to contain only
515     that item. Raise an exception if *iterable* is empty or has more than one
516     item.
517
518     :func:`one` is useful for ensuring that an iterable contains only one item.
519     For example, it can be used to retrieve the result of a database query
520     that is expected to return a single row.
521
522     If *iterable* is empty, ``ValueError`` will be raised. You may specify a
523     different exception with the *too_short* keyword:
524
525         >>> it = []
526         >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
527         Traceback (most recent call last):
528         ...
529         ValueError: too many items in iterable (expected 1)'
530         >>> too_short = IndexError('too few items')
531         >>> one(it, too_short=too_short)  # doctest: +IGNORE_EXCEPTION_DETAIL
532         Traceback (most recent call last):
533         ...
534         IndexError: too few items
535
536     Similarly, if *iterable* contains more than one item, ``ValueError`` will
537     be raised. You may specify a different exception with the *too_long*
538     keyword:
539
540         >>> it = ['too', 'many']
541         >>> one(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
542         Traceback (most recent call last):
543         ...
544         ValueError: Expected exactly one item in iterable, but got 'too',
545         'many', and perhaps more.
546         >>> too_long = RuntimeError
547         >>> one(it, too_long=too_long)  # doctest: +IGNORE_EXCEPTION_DETAIL
548         Traceback (most recent call last):
549         ...
550         RuntimeError
551
552     Note that :func:`one` attempts to advance *iterable* twice to ensure there
553     is only one item. See :func:`spy` or :func:`peekable` to check iterable
554     contents less destructively.
555
556     """
557     it = iter(iterable)
558
559     try:
560         first_value = next(it)
561     except StopIteration as e:
562         raise (
563             too_short or ValueError('too few items in iterable (expected 1)')
564         ) from e
565
566     try:
567         second_value = next(it)
568     except StopIteration:
569         pass
570     else:
571         msg = (
572             'Expected exactly one item in iterable, but got {!r}, {!r}, '
573             'and perhaps more.'.format(first_value, second_value)
574         )
575         raise too_long or ValueError(msg)
576
577     return first_value
578
579
580 def distinct_permutations(iterable, r=None):
581     """Yield successive distinct permutations of the elements in *iterable*.
582
583         >>> sorted(distinct_permutations([1, 0, 1]))
584         [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
585
586     Equivalent to ``set(permutations(iterable))``, except duplicates are not
587     generated and thrown away. For larger input sequences this is much more
588     efficient.
589
590     Duplicate permutations arise when there are duplicated elements in the
591     input iterable. The number of items returned is
592     `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of
593     items input, and each `x_i` is the count of a distinct item in the input
594     sequence.
595
596     If *r* is given, only the *r*-length permutations are yielded.
597
598         >>> sorted(distinct_permutations([1, 0, 1], r=2))
599         [(0, 1), (1, 0), (1, 1)]
600         >>> sorted(distinct_permutations(range(3), r=2))
601         [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
602
603     """
604     # Algorithm: https://w.wiki/Qai
605     def _full(A):
606         while True:
607             # Yield the permutation we have
608             yield tuple(A)
609
610             # Find the largest index i such that A[i] < A[i + 1]
611             for i in range(size - 2, -1, -1):
612                 if A[i] < A[i + 1]:
613                     break
614             #  If no such index exists, this permutation is the last one
615             else:
616                 return
617
618             # Find the largest index j greater than j such that A[i] < A[j]
619             for j in range(size - 1, i, -1):
620                 if A[i] < A[j]:
621                     break
622
623             # Swap the value of A[i] with that of A[j], then reverse the
624             # sequence from A[i + 1] to form the new permutation
625             A[i], A[j] = A[j], A[i]
626             A[i + 1 :] = A[: i - size : -1]  # A[i + 1:][::-1]
627
628     # Algorithm: modified from the above
629     def _partial(A, r):
630         # Split A into the first r items and the last r items
631         head, tail = A[:r], A[r:]
632         right_head_indexes = range(r - 1, -1, -1)
633         left_tail_indexes = range(len(tail))
634
635         while True:
636             # Yield the permutation we have
637             yield tuple(head)
638
639             # Starting from the right, find the first index of the head with
640             # value smaller than the maximum value of the tail - call it i.
641             pivot = tail[-1]
642             for i in right_head_indexes:
643                 if head[i] < pivot:
644                     break
645                 pivot = head[i]
646             else:
647                 return
648
649             # Starting from the left, find the first value of the tail
650             # with a value greater than head[i] and swap.
651             for j in left_tail_indexes:
652                 if tail[j] > head[i]:
653                     head[i], tail[j] = tail[j], head[i]
654                     break
655             # If we didn't find one, start from the right and find the first
656             # index of the head with a value greater than head[i] and swap.
657             else:
658                 for j in right_head_indexes:
659                     if head[j] > head[i]:
660                         head[i], head[j] = head[j], head[i]
661                         break
662
663             # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)]
664             tail += head[: i - r : -1]  # head[i + 1:][::-1]
665             i += 1
666             head[i:], tail[:] = tail[: r - i], tail[r - i :]
667
668     items = sorted(iterable)
669
670     size = len(items)
671     if r is None:
672         r = size
673
674     if 0 < r <= size:
675         return _full(items) if (r == size) else _partial(items, r)
676
677     return iter(() if r else ((),))
678
679
680 def intersperse(e, iterable, n=1):
681     """Intersperse filler element *e* among the items in *iterable*, leaving
682     *n* items between each filler element.
683
684         >>> list(intersperse('!', [1, 2, 3, 4, 5]))
685         [1, '!', 2, '!', 3, '!', 4, '!', 5]
686
687         >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2))
688         [1, 2, None, 3, 4, None, 5]
689
690     """
691     if n == 0:
692         raise ValueError('n must be > 0')
693     elif n == 1:
694         # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2...
695         # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2...
696         return islice(interleave(repeat(e), iterable), 1, None)
697     else:
698         # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]...
699         # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]...
700         # flatten(...) -> x_0, x_1, e, x_2, x_3...
701         filler = repeat([e])
702         chunks = chunked(iterable, n)
703         return flatten(islice(interleave(filler, chunks), 1, None))
704
705
706 def unique_to_each(*iterables):
707     """Return the elements from each of the input iterables that aren't in the
708     other input iterables.
709
710     For example, suppose you have a set of packages, each with a set of
711     dependencies::
712
713         {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}}
714
715     If you remove one package, which dependencies can also be removed?
716
717     If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not
718     associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for
719     ``pkg_2``, and ``D`` is only needed for ``pkg_3``::
720
721         >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'})
722         [['A'], ['C'], ['D']]
723
724     If there are duplicates in one input iterable that aren't in the others
725     they will be duplicated in the output. Input order is preserved::
726
727         >>> unique_to_each("mississippi", "missouri")
728         [['p', 'p'], ['o', 'u', 'r']]
729
730     It is assumed that the elements of each iterable are hashable.
731
732     """
733     pool = [list(it) for it in iterables]
734     counts = Counter(chain.from_iterable(map(set, pool)))
735     uniques = {element for element in counts if counts[element] == 1}
736     return [list(filter(uniques.__contains__, it)) for it in pool]
737
738
739 def windowed(seq, n, fillvalue=None, step=1):
740     """Return a sliding window of width *n* over the given iterable.
741
742         >>> all_windows = windowed([1, 2, 3, 4, 5], 3)
743         >>> list(all_windows)
744         [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
745
746     When the window is larger than the iterable, *fillvalue* is used in place
747     of missing values:
748
749         >>> list(windowed([1, 2, 3], 4))
750         [(1, 2, 3, None)]
751
752     Each window will advance in increments of *step*:
753
754         >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2))
755         [(1, 2, 3), (3, 4, 5), (5, 6, '!')]
756
757     To slide into the iterable's items, use :func:`chain` to add filler items
758     to the left:
759
760         >>> iterable = [1, 2, 3, 4]
761         >>> n = 3
762         >>> padding = [None] * (n - 1)
763         >>> list(windowed(chain(padding, iterable), 3))
764         [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)]
765     """
766     if n < 0:
767         raise ValueError('n must be >= 0')
768     if n == 0:
769         yield tuple()
770         return
771     if step < 1:
772         raise ValueError('step must be >= 1')
773
774     window = deque(maxlen=n)
775     i = n
776     for _ in map(window.append, seq):
777         i -= 1
778         if not i:
779             i = step
780             yield tuple(window)
781
782     size = len(window)
783     if size < n:
784         yield tuple(chain(window, repeat(fillvalue, n - size)))
785     elif 0 < i < min(step, n):
786         window += (fillvalue,) * i
787         yield tuple(window)
788
789
790 def substrings(iterable):
791     """Yield all of the substrings of *iterable*.
792
793         >>> [''.join(s) for s in substrings('more')]
794         ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more']
795
796     Note that non-string iterables can also be subdivided.
797
798         >>> list(substrings([0, 1, 2]))
799         [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)]
800
801     """
802     # The length-1 substrings
803     seq = []
804     for item in iter(iterable):
805         seq.append(item)
806         yield (item,)
807     seq = tuple(seq)
808     item_count = len(seq)
809
810     # And the rest
811     for n in range(2, item_count + 1):
812         for i in range(item_count - n + 1):
813             yield seq[i : i + n]
814
815
816 def substrings_indexes(seq, reverse=False):
817     """Yield all substrings and their positions in *seq*
818
819     The items yielded will be a tuple of the form ``(substr, i, j)``, where
820     ``substr == seq[i:j]``.
821
822     This function only works for iterables that support slicing, such as
823     ``str`` objects.
824
825     >>> for item in substrings_indexes('more'):
826     ...    print(item)
827     ('m', 0, 1)
828     ('o', 1, 2)
829     ('r', 2, 3)
830     ('e', 3, 4)
831     ('mo', 0, 2)
832     ('or', 1, 3)
833     ('re', 2, 4)
834     ('mor', 0, 3)
835     ('ore', 1, 4)
836     ('more', 0, 4)
837
838     Set *reverse* to ``True`` to yield the same items in the opposite order.
839
840
841     """
842     r = range(1, len(seq) + 1)
843     if reverse:
844         r = reversed(r)
845     return (
846         (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1)
847     )
848
849
850 class bucket:
851     """Wrap *iterable* and return an object that buckets it iterable into
852     child iterables based on a *key* function.
853
854         >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
855         >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
856         >>> sorted(list(s))  # Get the keys
857         ['a', 'b', 'c']
858         >>> a_iterable = s['a']
859         >>> next(a_iterable)
860         'a1'
861         >>> next(a_iterable)
862         'a2'
863         >>> list(s['b'])
864         ['b1', 'b2', 'b3']
865
866     The original iterable will be advanced and its items will be cached until
867     they are used by the child iterables. This may require significant storage.
868
869     By default, attempting to select a bucket to which no items belong  will
870     exhaust the iterable and cache all values.
871     If you specify a *validator* function, selected buckets will instead be
872     checked against it.
873
874         >>> from itertools import count
875         >>> it = count(1, 2)  # Infinite sequence of odd numbers
876         >>> key = lambda x: x % 10  # Bucket by last digit
877         >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
878         >>> s = bucket(it, key=key, validator=validator)
879         >>> 2 in s
880         False
881         >>> list(s[2])
882         []
883
884     """
885
886     def __init__(self, iterable, key, validator=None):
887         self._it = iter(iterable)
888         self._key = key
889         self._cache = defaultdict(deque)
890         self._validator = validator or (lambda x: True)
891
892     def __contains__(self, value):
893         if not self._validator(value):
894             return False
895
896         try:
897             item = next(self[value])
898         except StopIteration:
899             return False
900         else:
901             self._cache[value].appendleft(item)
902
903         return True
904
905     def _get_values(self, value):
906         """
907         Helper to yield items from the parent iterator that match *value*.
908         Items that don't match are stored in the local cache as they
909         are encountered.
910         """
911         while True:
912             # If we've cached some items that match the target value, emit
913             # the first one and evict it from the cache.
914             if self._cache[value]:
915                 yield self._cache[value].popleft()
916             # Otherwise we need to advance the parent iterator to search for
917             # a matching item, caching the rest.
918             else:
919                 while True:
920                     try:
921                         item = next(self._it)
922                     except StopIteration:
923                         return
924                     item_value = self._key(item)
925                     if item_value == value:
926                         yield item
927                         break
928                     elif self._validator(item_value):
929                         self._cache[item_value].append(item)
930
931     def __iter__(self):
932         for item in self._it:
933             item_value = self._key(item)
934             if self._validator(item_value):
935                 self._cache[item_value].append(item)
936
937         yield from self._cache.keys()
938
939     def __getitem__(self, value):
940         if not self._validator(value):
941             return iter(())
942
943         return self._get_values(value)
944
945
946 def spy(iterable, n=1):
947     """Return a 2-tuple with a list containing the first *n* elements of
948     *iterable*, and an iterator with the same items as *iterable*.
949     This allows you to "look ahead" at the items in the iterable without
950     advancing it.
951
952     There is one item in the list by default:
953
954         >>> iterable = 'abcdefg'
955         >>> head, iterable = spy(iterable)
956         >>> head
957         ['a']
958         >>> list(iterable)
959         ['a', 'b', 'c', 'd', 'e', 'f', 'g']
960
961     You may use unpacking to retrieve items instead of lists:
962
963         >>> (head,), iterable = spy('abcdefg')
964         >>> head
965         'a'
966         >>> (first, second), iterable = spy('abcdefg', 2)
967         >>> first
968         'a'
969         >>> second
970         'b'
971
972     The number of items requested can be larger than the number of items in
973     the iterable:
974
975         >>> iterable = [1, 2, 3, 4, 5]
976         >>> head, iterable = spy(iterable, 10)
977         >>> head
978         [1, 2, 3, 4, 5]
979         >>> list(iterable)
980         [1, 2, 3, 4, 5]
981
982     """
983     it = iter(iterable)
984     head = take(n, it)
985
986     return head.copy(), chain(head, it)
987
988
989 def interleave(*iterables):
990     """Return a new iterable yielding from each iterable in turn,
991     until the shortest is exhausted.
992
993         >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8]))
994         [1, 4, 6, 2, 5, 7]
995
996     For a version that doesn't terminate after the shortest iterable is
997     exhausted, see :func:`interleave_longest`.
998
999     """
1000     return chain.from_iterable(zip(*iterables))
1001
1002
1003 def interleave_longest(*iterables):
1004     """Return a new iterable yielding from each iterable in turn,
1005     skipping any that are exhausted.
1006
1007         >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8]))
1008         [1, 4, 6, 2, 5, 7, 3, 8]
1009
1010     This function produces the same output as :func:`roundrobin`, but may
1011     perform better for some inputs (in particular when the number of iterables
1012     is large).
1013
1014     """
1015     i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker))
1016     return (x for x in i if x is not _marker)
1017
1018
1019 def collapse(iterable, base_type=None, levels=None):
1020     """Flatten an iterable with multiple levels of nesting (e.g., a list of
1021     lists of tuples) into non-iterable types.
1022
1023         >>> iterable = [(1, 2), ([3, 4], [[5], [6]])]
1024         >>> list(collapse(iterable))
1025         [1, 2, 3, 4, 5, 6]
1026
1027     Binary and text strings are not considered iterable and
1028     will not be collapsed.
1029
1030     To avoid collapsing other types, specify *base_type*:
1031
1032         >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']]
1033         >>> list(collapse(iterable, base_type=tuple))
1034         ['ab', ('cd', 'ef'), 'gh', 'ij']
1035
1036     Specify *levels* to stop flattening after a certain level:
1037
1038     >>> iterable = [('a', ['b']), ('c', ['d'])]
1039     >>> list(collapse(iterable))  # Fully flattened
1040     ['a', 'b', 'c', 'd']
1041     >>> list(collapse(iterable, levels=1))  # Only one level flattened
1042     ['a', ['b'], 'c', ['d']]
1043
1044     """
1045
1046     def walk(node, level):
1047         if (
1048             ((levels is not None) and (level > levels))
1049             or isinstance(node, (str, bytes))
1050             or ((base_type is not None) and isinstance(node, base_type))
1051         ):
1052             yield node
1053             return
1054
1055         try:
1056             tree = iter(node)
1057         except TypeError:
1058             yield node
1059             return
1060         else:
1061             for child in tree:
1062                 yield from walk(child, level + 1)
1063
1064     yield from walk(iterable, 0)
1065
1066
1067 def side_effect(func, iterable, chunk_size=None, before=None, after=None):
1068     """Invoke *func* on each item in *iterable* (or on each *chunk_size* group
1069     of items) before yielding the item.
1070
1071     `func` must be a function that takes a single argument. Its return value
1072     will be discarded.
1073
1074     *before* and *after* are optional functions that take no arguments. They
1075     will be executed before iteration starts and after it ends, respectively.
1076
1077     `side_effect` can be used for logging, updating progress bars, or anything
1078     that is not functionally "pure."
1079
1080     Emitting a status message:
1081
1082         >>> from more_itertools import consume
1083         >>> func = lambda item: print('Received {}'.format(item))
1084         >>> consume(side_effect(func, range(2)))
1085         Received 0
1086         Received 1
1087
1088     Operating on chunks of items:
1089
1090         >>> pair_sums = []
1091         >>> func = lambda chunk: pair_sums.append(sum(chunk))
1092         >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2))
1093         [0, 1, 2, 3, 4, 5]
1094         >>> list(pair_sums)
1095         [1, 5, 9]
1096
1097     Writing to a file-like object:
1098
1099         >>> from io import StringIO
1100         >>> from more_itertools import consume
1101         >>> f = StringIO()
1102         >>> func = lambda x: print(x, file=f)
1103         >>> before = lambda: print(u'HEADER', file=f)
1104         >>> after = f.close
1105         >>> it = [u'a', u'b', u'c']
1106         >>> consume(side_effect(func, it, before=before, after=after))
1107         >>> f.closed
1108         True
1109
1110     """
1111     try:
1112         if before is not None:
1113             before()
1114
1115         if chunk_size is None:
1116             for item in iterable:
1117                 func(item)
1118                 yield item
1119         else:
1120             for chunk in chunked(iterable, chunk_size):
1121                 func(chunk)
1122                 yield from chunk
1123     finally:
1124         if after is not None:
1125             after()
1126
1127
1128 def sliced(seq, n, strict=False):
1129     """Yield slices of length *n* from the sequence *seq*.
1130
1131     >>> list(sliced((1, 2, 3, 4, 5, 6), 3))
1132     [(1, 2, 3), (4, 5, 6)]
1133
1134     By the default, the last yielded slice will have fewer than *n* elements
1135     if the length of *seq* is not divisible by *n*:
1136
1137     >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3))
1138     [(1, 2, 3), (4, 5, 6), (7, 8)]
1139
1140     If the length of *seq* is not divisible by *n* and *strict* is
1141     ``True``, then ``ValueError`` will be raised before the last
1142     slice is yielded.
1143
1144     This function will only work for iterables that support slicing.
1145     For non-sliceable iterables, see :func:`chunked`.
1146
1147     """
1148     iterator = takewhile(len, (seq[i : i + n] for i in count(0, n)))
1149     if strict:
1150
1151         def ret():
1152             for _slice in iterator:
1153                 if len(_slice) != n:
1154                     raise ValueError("seq is not divisible by n.")
1155                 yield _slice
1156
1157         return iter(ret())
1158     else:
1159         return iterator
1160
1161
1162 def split_at(iterable, pred, maxsplit=-1, keep_separator=False):
1163     """Yield lists of items from *iterable*, where each list is delimited by
1164     an item where callable *pred* returns ``True``.
1165
1166         >>> list(split_at('abcdcba', lambda x: x == 'b'))
1167         [['a'], ['c', 'd', 'c'], ['a']]
1168
1169         >>> list(split_at(range(10), lambda n: n % 2 == 1))
1170         [[0], [2], [4], [6], [8], []]
1171
1172     At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1173     then there is no limit on the number of splits:
1174
1175         >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
1176         [[0], [2], [4, 5, 6, 7, 8, 9]]
1177
1178     By default, the delimiting items are not included in the output.
1179     The include them, set *keep_separator* to ``True``.
1180
1181         >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
1182         [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]
1183
1184     """
1185     if maxsplit == 0:
1186         yield list(iterable)
1187         return
1188
1189     buf = []
1190     it = iter(iterable)
1191     for item in it:
1192         if pred(item):
1193             yield buf
1194             if keep_separator:
1195                 yield [item]
1196             if maxsplit == 1:
1197                 yield list(it)
1198                 return
1199             buf = []
1200             maxsplit -= 1
1201         else:
1202             buf.append(item)
1203     yield buf
1204
1205
1206 def split_before(iterable, pred, maxsplit=-1):
1207     """Yield lists of items from *iterable*, where each list ends just before
1208     an item for which callable *pred* returns ``True``:
1209
1210         >>> list(split_before('OneTwo', lambda s: s.isupper()))
1211         [['O', 'n', 'e'], ['T', 'w', 'o']]
1212
1213         >>> list(split_before(range(10), lambda n: n % 3 == 0))
1214         [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
1215
1216     At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1217     then there is no limit on the number of splits:
1218
1219         >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
1220         [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
1221     """
1222     if maxsplit == 0:
1223         yield list(iterable)
1224         return
1225
1226     buf = []
1227     it = iter(iterable)
1228     for item in it:
1229         if pred(item) and buf:
1230             yield buf
1231             if maxsplit == 1:
1232                 yield [item] + list(it)
1233                 return
1234             buf = []
1235             maxsplit -= 1
1236         buf.append(item)
1237     if buf:
1238         yield buf
1239
1240
1241 def split_after(iterable, pred, maxsplit=-1):
1242     """Yield lists of items from *iterable*, where each list ends with an
1243     item where callable *pred* returns ``True``:
1244
1245         >>> list(split_after('one1two2', lambda s: s.isdigit()))
1246         [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']]
1247
1248         >>> list(split_after(range(10), lambda n: n % 3 == 0))
1249         [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]
1250
1251     At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1252     then there is no limit on the number of splits:
1253
1254         >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
1255         [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]
1256
1257     """
1258     if maxsplit == 0:
1259         yield list(iterable)
1260         return
1261
1262     buf = []
1263     it = iter(iterable)
1264     for item in it:
1265         buf.append(item)
1266         if pred(item) and buf:
1267             yield buf
1268             if maxsplit == 1:
1269                 yield list(it)
1270                 return
1271             buf = []
1272             maxsplit -= 1
1273     if buf:
1274         yield buf
1275
1276
1277 def split_when(iterable, pred, maxsplit=-1):
1278     """Split *iterable* into pieces based on the output of *pred*.
1279     *pred* should be a function that takes successive pairs of items and
1280     returns ``True`` if the iterable should be split in between them.
1281
1282     For example, to find runs of increasing numbers, split the iterable when
1283     element ``i`` is larger than element ``i + 1``:
1284
1285         >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
1286         [[1, 2, 3, 3], [2, 5], [2, 4], [2]]
1287
1288     At most *maxsplit* splits are done. If *maxsplit* is not specified or -1,
1289     then there is no limit on the number of splits:
1290
1291         >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
1292         ...                 lambda x, y: x > y, maxsplit=2))
1293         [[1, 2, 3, 3], [2, 5], [2, 4, 2]]
1294
1295     """
1296     if maxsplit == 0:
1297         yield list(iterable)
1298         return
1299
1300     it = iter(iterable)
1301     try:
1302         cur_item = next(it)
1303     except StopIteration:
1304         return
1305
1306     buf = [cur_item]
1307     for next_item in it:
1308         if pred(cur_item, next_item):
1309             yield buf
1310             if maxsplit == 1:
1311                 yield [next_item] + list(it)
1312                 return
1313             buf = []
1314             maxsplit -= 1
1315
1316         buf.append(next_item)
1317         cur_item = next_item
1318
1319     yield buf
1320
1321
1322 def split_into(iterable, sizes):
1323     """Yield a list of sequential items from *iterable* of length 'n' for each
1324     integer 'n' in *sizes*.
1325
1326         >>> list(split_into([1,2,3,4,5,6], [1,2,3]))
1327         [[1], [2, 3], [4, 5, 6]]
1328
1329     If the sum of *sizes* is smaller than the length of *iterable*, then the
1330     remaining items of *iterable* will not be returned.
1331
1332         >>> list(split_into([1,2,3,4,5,6], [2,3]))
1333         [[1, 2], [3, 4, 5]]
1334
1335     If the sum of *sizes* is larger than the length of *iterable*, fewer items
1336     will be returned in the iteration that overruns *iterable* and further
1337     lists will be empty:
1338
1339         >>> list(split_into([1,2,3,4], [1,2,3,4]))
1340         [[1], [2, 3], [4], []]
1341
1342     When a ``None`` object is encountered in *sizes*, the returned list will
1343     contain items up to the end of *iterable* the same way that itertools.slice
1344     does:
1345
1346         >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None]))
1347         [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]]
1348
1349     :func:`split_into` can be useful for grouping a series of items where the
1350     sizes of the groups are not uniform. An example would be where in a row
1351     from a table, multiple columns represent elements of the same feature
1352     (e.g. a point represented by x,y,z) but, the format is not the same for
1353     all columns.
1354     """
1355     # convert the iterable argument into an iterator so its contents can
1356     # be consumed by islice in case it is a generator
1357     it = iter(iterable)
1358
1359     for size in sizes:
1360         if size is None:
1361             yield list(it)
1362             return
1363         else:
1364             yield list(islice(it, size))
1365
1366
1367 def padded(iterable, fillvalue=None, n=None, next_multiple=False):
1368     """Yield the elements from *iterable*, followed by *fillvalue*, such that
1369     at least *n* items are emitted.
1370
1371         >>> list(padded([1, 2, 3], '?', 5))
1372         [1, 2, 3, '?', '?']
1373
1374     If *next_multiple* is ``True``, *fillvalue* will be emitted until the
1375     number of items emitted is a multiple of *n*::
1376
1377         >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True))
1378         [1, 2, 3, 4, None, None]
1379
1380     If *n* is ``None``, *fillvalue* will be emitted indefinitely.
1381
1382     """
1383     it = iter(iterable)
1384     if n is None:
1385         yield from chain(it, repeat(fillvalue))
1386     elif n < 1:
1387         raise ValueError('n must be at least 1')
1388     else:
1389         item_count = 0
1390         for item in it:
1391             yield item
1392             item_count += 1
1393
1394         remaining = (n - item_count) % n if next_multiple else n - item_count
1395         for _ in range(remaining):
1396             yield fillvalue
1397
1398
1399 def repeat_last(iterable, default=None):
1400     """After the *iterable* is exhausted, keep yielding its last element.
1401
1402         >>> list(islice(repeat_last(range(3)), 5))
1403         [0, 1, 2, 2, 2]
1404
1405     If the iterable is empty, yield *default* forever::
1406
1407         >>> list(islice(repeat_last(range(0), 42), 5))
1408         [42, 42, 42, 42, 42]
1409
1410     """
1411     item = _marker
1412     for item in iterable:
1413         yield item
1414     final = default if item is _marker else item
1415     yield from repeat(final)
1416
1417
1418 def distribute(n, iterable):
1419     """Distribute the items from *iterable* among *n* smaller iterables.
1420
1421         >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
1422         >>> list(group_1)
1423         [1, 3, 5]
1424         >>> list(group_2)
1425         [2, 4, 6]
1426
1427     If the length of *iterable* is not evenly divisible by *n*, then the
1428     length of the returned iterables will not be identical:
1429
1430         >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
1431         >>> [list(c) for c in children]
1432         [[1, 4, 7], [2, 5], [3, 6]]
1433
1434     If the length of *iterable* is smaller than *n*, then the last returned
1435     iterables will be empty:
1436
1437         >>> children = distribute(5, [1, 2, 3])
1438         >>> [list(c) for c in children]
1439         [[1], [2], [3], [], []]
1440
1441     This function uses :func:`itertools.tee` and may require significant
1442     storage. If you need the order items in the smaller iterables to match the
1443     original iterable, see :func:`divide`.
1444
1445     """
1446     if n < 1:
1447         raise ValueError('n must be at least 1')
1448
1449     children = tee(iterable, n)
1450     return [islice(it, index, None, n) for index, it in enumerate(children)]
1451
1452
1453 def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
1454     """Yield tuples whose elements are offset from *iterable*.
1455     The amount by which the `i`-th item in each tuple is offset is given by
1456     the `i`-th item in *offsets*.
1457
1458         >>> list(stagger([0, 1, 2, 3]))
1459         [(None, 0, 1), (0, 1, 2), (1, 2, 3)]
1460         >>> list(stagger(range(8), offsets=(0, 2, 4)))
1461         [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]
1462
1463     By default, the sequence will end when the final element of a tuple is the
1464     last item in the iterable. To continue until the first element of a tuple
1465     is the last item in the iterable, set *longest* to ``True``::
1466
1467         >>> list(stagger([0, 1, 2, 3], longest=True))
1468         [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)]
1469
1470     By default, ``None`` will be used to replace offsets beyond the end of the
1471     sequence. Specify *fillvalue* to use some other value.
1472
1473     """
1474     children = tee(iterable, len(offsets))
1475
1476     return zip_offset(
1477         *children, offsets=offsets, longest=longest, fillvalue=fillvalue
1478     )
1479
1480
1481 class UnequalIterablesError(ValueError):
1482     def __init__(self, details=None):
1483         msg = 'Iterables have different lengths'
1484         if details is not None:
1485             msg += (': index 0 has length {}; index {} has length {}').format(
1486                 *details
1487             )
1488
1489         super().__init__(msg)
1490
1491
1492 def _zip_equal_generator(iterables):
1493     for combo in zip_longest(*iterables, fillvalue=_marker):
1494         for val in combo:
1495             if val is _marker:
1496                 raise UnequalIterablesError()
1497         yield combo
1498
1499
1500 def zip_equal(*iterables):
1501     """``zip`` the input *iterables* together, but raise
1502     ``UnequalIterablesError`` if they aren't all the same length.
1503
1504         >>> it_1 = range(3)
1505         >>> it_2 = iter('abc')
1506         >>> list(zip_equal(it_1, it_2))
1507         [(0, 'a'), (1, 'b'), (2, 'c')]
1508
1509         >>> it_1 = range(3)
1510         >>> it_2 = iter('abcd')
1511         >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
1512         Traceback (most recent call last):
1513         ...
1514         more_itertools.more.UnequalIterablesError: Iterables have different
1515         lengths
1516
1517     """
1518     if hexversion >= 0x30A00A6:
1519         warnings.warn(
1520             (
1521                 'zip_equal will be removed in a future version of '
1522                 'more-itertools. Use the builtin zip function with '
1523                 'strict=True instead.'
1524             ),
1525             DeprecationWarning,
1526         )
1527     # Check whether the iterables are all the same size.
1528     try:
1529         first_size = len(iterables[0])
1530         for i, it in enumerate(iterables[1:], 1):
1531             size = len(it)
1532             if size != first_size:
1533                 break
1534         else:
1535             # If we didn't break out, we can use the built-in zip.
1536             return zip(*iterables)
1537
1538         # If we did break out, there was a mismatch.
1539         raise UnequalIterablesError(details=(first_size, i, size))
1540     # If any one of the iterables didn't have a length, start reading
1541     # them until one runs out.
1542     except TypeError:
1543         return _zip_equal_generator(iterables)
1544
1545
1546 def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
1547     """``zip`` the input *iterables* together, but offset the `i`-th iterable
1548     by the `i`-th item in *offsets*.
1549
1550         >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1)))
1551         [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')]
1552
1553     This can be used as a lightweight alternative to SciPy or pandas to analyze
1554     data sets in which some series have a lead or lag relationship.
1555
1556     By default, the sequence will end when the shortest iterable is exhausted.
1557     To continue until the longest iterable is exhausted, set *longest* to
1558     ``True``.
1559
1560         >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True))
1561         [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')]
1562
1563     By default, ``None`` will be used to replace offsets beyond the end of the
1564     sequence. Specify *fillvalue* to use some other value.
1565
1566     """
1567     if len(iterables) != len(offsets):
1568         raise ValueError("Number of iterables and offsets didn't match")
1569
1570     staggered = []
1571     for it, n in zip(iterables, offsets):
1572         if n < 0:
1573             staggered.append(chain(repeat(fillvalue, -n), it))
1574         elif n > 0:
1575             staggered.append(islice(it, n, None))
1576         else:
1577             staggered.append(it)
1578
1579     if longest:
1580         return zip_longest(*staggered, fillvalue=fillvalue)
1581
1582     return zip(*staggered)
1583
1584
1585 def sort_together(iterables, key_list=(0,), key=None, reverse=False):
1586     """Return the input iterables sorted together, with *key_list* as the
1587     priority for sorting. All iterables are trimmed to the length of the
1588     shortest one.
1589
1590     This can be used like the sorting function in a spreadsheet. If each
1591     iterable represents a column of data, the key list determines which
1592     columns are used for sorting.
1593
1594     By default, all iterables are sorted using the ``0``-th iterable::
1595
1596         >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')]
1597         >>> sort_together(iterables)
1598         [(1, 2, 3, 4), ('d', 'c', 'b', 'a')]
1599
1600     Set a different key list to sort according to another iterable.
1601     Specifying multiple keys dictates how ties are broken::
1602
1603         >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')]
1604         >>> sort_together(iterables, key_list=(1, 2))
1605         [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')]
1606
1607     To sort by a function of the elements of the iterable, pass a *key*
1608     function. Its arguments are the elements of the iterables corresponding to
1609     the key list::
1610
1611         >>> names = ('a', 'b', 'c')
1612         >>> lengths = (1, 2, 3)
1613         >>> widths = (5, 2, 1)
1614         >>> def area(length, width):
1615         ...     return length * width
1616         >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area)
1617         [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)]
1618
1619     Set *reverse* to ``True`` to sort in descending order.
1620
1621         >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
1622         [(3, 2, 1), ('a', 'b', 'c')]
1623
1624     """
1625     if key is None:
1626         # if there is no key function, the key argument to sorted is an
1627         # itemgetter
1628         key_argument = itemgetter(*key_list)
1629     else:
1630         # if there is a key function, call it with the items at the offsets
1631         # specified by the key function as arguments
1632         key_list = list(key_list)
1633         if len(key_list) == 1:
1634             # if key_list contains a single item, pass the item at that offset
1635             # as the only argument to the key function
1636             key_offset = key_list[0]
1637             key_argument = lambda zipped_items: key(zipped_items[key_offset])
1638         else:
1639             # if key_list contains multiple items, use itemgetter to return a
1640             # tuple of items, which we pass as *args to the key function
1641             get_key_items = itemgetter(*key_list)
1642             key_argument = lambda zipped_items: key(
1643                 *get_key_items(zipped_items)
1644             )
1645
1646     return list(
1647         zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse))
1648     )
1649
1650
1651 def unzip(iterable):
1652     """The inverse of :func:`zip`, this function disaggregates the elements
1653     of the zipped *iterable*.
1654
1655     The ``i``-th iterable contains the ``i``-th element from each element
1656     of the zipped iterable. The first element is used to to determine the
1657     length of the remaining elements.
1658
1659         >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
1660         >>> letters, numbers = unzip(iterable)
1661         >>> list(letters)
1662         ['a', 'b', 'c', 'd']
1663         >>> list(numbers)
1664         [1, 2, 3, 4]
1665
1666     This is similar to using ``zip(*iterable)``, but it avoids reading
1667     *iterable* into memory. Note, however, that this function uses
1668     :func:`itertools.tee` and thus may require significant storage.
1669
1670     """
1671     head, iterable = spy(iter(iterable))
1672     if not head:
1673         # empty iterable, e.g. zip([], [], [])
1674         return ()
1675     # spy returns a one-length iterable as head
1676     head = head[0]
1677     iterables = tee(iterable, len(head))
1678
1679     def itemgetter(i):
1680         def getter(obj):
1681             try:
1682                 return obj[i]
1683             except IndexError:
1684                 # basically if we have an iterable like
1685                 # iter([(1, 2, 3), (4, 5), (6,)])
1686                 # the second unzipped iterable would fail at the third tuple
1687                 # since it would try to access tup[1]
1688                 # same with the third unzipped iterable and the second tuple
1689                 # to support these "improperly zipped" iterables,
1690                 # we create a custom itemgetter
1691                 # which just stops the unzipped iterables
1692                 # at first length mismatch
1693                 raise StopIteration
1694
1695         return getter
1696
1697     return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables))
1698
1699
1700 def divide(n, iterable):
1701     """Divide the elements from *iterable* into *n* parts, maintaining
1702     order.
1703
1704         >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
1705         >>> list(group_1)
1706         [1, 2, 3]
1707         >>> list(group_2)
1708         [4, 5, 6]
1709
1710     If the length of *iterable* is not evenly divisible by *n*, then the
1711     length of the returned iterables will not be identical:
1712
1713         >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
1714         >>> [list(c) for c in children]
1715         [[1, 2, 3], [4, 5], [6, 7]]
1716
1717     If the length of the iterable is smaller than n, then the last returned
1718     iterables will be empty:
1719
1720         >>> children = divide(5, [1, 2, 3])
1721         >>> [list(c) for c in children]
1722         [[1], [2], [3], [], []]
1723
1724     This function will exhaust the iterable before returning and may require
1725     significant storage. If order is not important, see :func:`distribute`,
1726     which does not first pull the iterable into memory.
1727
1728     """
1729     if n < 1:
1730         raise ValueError('n must be at least 1')
1731
1732     try:
1733         iterable[:0]
1734     except TypeError:
1735         seq = tuple(iterable)
1736     else:
1737         seq = iterable
1738
1739     q, r = divmod(len(seq), n)
1740
1741     ret = []
1742     stop = 0
1743     for i in range(1, n + 1):
1744         start = stop
1745         stop += q + 1 if i <= r else q
1746         ret.append(iter(seq[start:stop]))
1747
1748     return ret
1749
1750
1751 def always_iterable(obj, base_type=(str, bytes)):
1752     """If *obj* is iterable, return an iterator over its items::
1753
1754         >>> obj = (1, 2, 3)
1755         >>> list(always_iterable(obj))
1756         [1, 2, 3]
1757
1758     If *obj* is not iterable, return a one-item iterable containing *obj*::
1759
1760         >>> obj = 1
1761         >>> list(always_iterable(obj))
1762         [1]
1763
1764     If *obj* is ``None``, return an empty iterable:
1765
1766         >>> obj = None
1767         >>> list(always_iterable(None))
1768         []
1769
1770     By default, binary and text strings are not considered iterable::
1771
1772         >>> obj = 'foo'
1773         >>> list(always_iterable(obj))
1774         ['foo']
1775
1776     If *base_type* is set, objects for which ``isinstance(obj, base_type)``
1777     returns ``True`` won't be considered iterable.
1778
1779         >>> obj = {'a': 1}
1780         >>> list(always_iterable(obj))  # Iterate over the dict's keys
1781         ['a']
1782         >>> list(always_iterable(obj, base_type=dict))  # Treat dicts as a unit
1783         [{'a': 1}]
1784
1785     Set *base_type* to ``None`` to avoid any special handling and treat objects
1786     Python considers iterable as iterable:
1787
1788         >>> obj = 'foo'
1789         >>> list(always_iterable(obj, base_type=None))
1790         ['f', 'o', 'o']
1791     """
1792     if obj is None:
1793         return iter(())
1794
1795     if (base_type is not None) and isinstance(obj, base_type):
1796         return iter((obj,))
1797
1798     try:
1799         return iter(obj)
1800     except TypeError:
1801         return iter((obj,))
1802
1803
1804 def adjacent(predicate, iterable, distance=1):
1805     """Return an iterable over `(bool, item)` tuples where the `item` is
1806     drawn from *iterable* and the `bool` indicates whether
1807     that item satisfies the *predicate* or is adjacent to an item that does.
1808
1809     For example, to find whether items are adjacent to a ``3``::
1810
1811         >>> list(adjacent(lambda x: x == 3, range(6)))
1812         [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)]
1813
1814     Set *distance* to change what counts as adjacent. For example, to find
1815     whether items are two places away from a ``3``:
1816
1817         >>> list(adjacent(lambda x: x == 3, range(6), distance=2))
1818         [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)]
1819
1820     This is useful for contextualizing the results of a search function.
1821     For example, a code comparison tool might want to identify lines that
1822     have changed, but also surrounding lines to give the viewer of the diff
1823     context.
1824
1825     The predicate function will only be called once for each item in the
1826     iterable.
1827
1828     See also :func:`groupby_transform`, which can be used with this function
1829     to group ranges of items with the same `bool` value.
1830
1831     """
1832     # Allow distance=0 mainly for testing that it reproduces results with map()
1833     if distance < 0:
1834         raise ValueError('distance must be at least 0')
1835
1836     i1, i2 = tee(iterable)
1837     padding = [False] * distance
1838     selected = chain(padding, map(predicate, i1), padding)
1839     adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1))
1840     return zip(adjacent_to_selected, i2)
1841
1842
1843 def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None):
1844     """An extension of :func:`itertools.groupby` that can apply transformations
1845     to the grouped data.
1846
1847     * *keyfunc* is a function computing a key value for each item in *iterable*
1848     * *valuefunc* is a function that transforms the individual items from
1849       *iterable* after grouping
1850     * *reducefunc* is a function that transforms each group of items
1851
1852     >>> iterable = 'aAAbBBcCC'
1853     >>> keyfunc = lambda k: k.upper()
1854     >>> valuefunc = lambda v: v.lower()
1855     >>> reducefunc = lambda g: ''.join(g)
1856     >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc))
1857     [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')]
1858
1859     Each optional argument defaults to an identity function if not specified.
1860
1861     :func:`groupby_transform` is useful when grouping elements of an iterable
1862     using a separate iterable as the key. To do this, :func:`zip` the iterables
1863     and pass a *keyfunc* that extracts the first element and a *valuefunc*
1864     that extracts the second element::
1865
1866         >>> from operator import itemgetter
1867         >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3]
1868         >>> values = 'abcdefghi'
1869         >>> iterable = zip(keys, values)
1870         >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1))
1871         >>> [(k, ''.join(g)) for k, g in grouper]
1872         [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')]
1873
1874     Note that the order of items in the iterable is significant.
1875     Only adjacent items are grouped together, so if you don't want any
1876     duplicate groups, you should sort the iterable by the key function.
1877
1878     """
1879     ret = groupby(iterable, keyfunc)
1880     if valuefunc:
1881         ret = ((k, map(valuefunc, g)) for k, g in ret)
1882     if reducefunc:
1883         ret = ((k, reducefunc(g)) for k, g in ret)
1884
1885     return ret
1886
1887
1888 class numeric_range(abc.Sequence, abc.Hashable):
1889     """An extension of the built-in ``range()`` function whose arguments can
1890     be any orderable numeric type.
1891
1892     With only *stop* specified, *start* defaults to ``0`` and *step*
1893     defaults to ``1``. The output items will match the type of *stop*:
1894
1895         >>> list(numeric_range(3.5))
1896         [0.0, 1.0, 2.0, 3.0]
1897
1898     With only *start* and *stop* specified, *step* defaults to ``1``. The
1899     output items will match the type of *start*:
1900
1901         >>> from decimal import Decimal
1902         >>> start = Decimal('2.1')
1903         >>> stop = Decimal('5.1')
1904         >>> list(numeric_range(start, stop))
1905         [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')]
1906
1907     With *start*, *stop*, and *step*  specified the output items will match
1908     the type of ``start + step``:
1909
1910         >>> from fractions import Fraction
1911         >>> start = Fraction(1, 2)  # Start at 1/2
1912         >>> stop = Fraction(5, 2)  # End at 5/2
1913         >>> step = Fraction(1, 2)  # Count by 1/2
1914         >>> list(numeric_range(start, stop, step))
1915         [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)]
1916
1917     If *step* is zero, ``ValueError`` is raised. Negative steps are supported:
1918
1919         >>> list(numeric_range(3, -1, -1.0))
1920         [3.0, 2.0, 1.0, 0.0]
1921
1922     Be aware of the limitations of floating point numbers; the representation
1923     of the yielded numbers may be surprising.
1924
1925     ``datetime.datetime`` objects can be used for *start* and *stop*, if *step*
1926     is a ``datetime.timedelta`` object:
1927
1928         >>> import datetime
1929         >>> start = datetime.datetime(2019, 1, 1)
1930         >>> stop = datetime.datetime(2019, 1, 3)
1931         >>> step = datetime.timedelta(days=1)
1932         >>> items = iter(numeric_range(start, stop, step))
1933         >>> next(items)
1934         datetime.datetime(2019, 1, 1, 0, 0)
1935         >>> next(items)
1936         datetime.datetime(2019, 1, 2, 0, 0)
1937
1938     """
1939
1940     _EMPTY_HASH = hash(range(0, 0))
1941
1942     def __init__(self, *args):
1943         argc = len(args)
1944         if argc == 1:
1945             (self._stop,) = args
1946             self._start = type(self._stop)(0)
1947             self._step = type(self._stop - self._start)(1)
1948         elif argc == 2:
1949             self._start, self._stop = args
1950             self._step = type(self._stop - self._start)(1)
1951         elif argc == 3:
1952             self._start, self._stop, self._step = args
1953         elif argc == 0:
1954             raise TypeError(
1955                 'numeric_range expected at least '
1956                 '1 argument, got {}'.format(argc)
1957             )
1958         else:
1959             raise TypeError(
1960                 'numeric_range expected at most '
1961                 '3 arguments, got {}'.format(argc)
1962             )
1963
1964         self._zero = type(self._step)(0)
1965         if self._step == self._zero:
1966             raise ValueError('numeric_range() arg 3 must not be zero')
1967         self._growing = self._step > self._zero
1968         self._init_len()
1969
1970     def __bool__(self):
1971         if self._growing:
1972             return self._start < self._stop
1973         else:
1974             return self._start > self._stop
1975
1976     def __contains__(self, elem):
1977         if self._growing:
1978             if self._start <= elem < self._stop:
1979                 return (elem - self._start) % self._step == self._zero
1980         else:
1981             if self._start >= elem > self._stop:
1982                 return (self._start - elem) % (-self._step) == self._zero
1983
1984         return False
1985
1986     def __eq__(self, other):
1987         if isinstance(other, numeric_range):
1988             empty_self = not bool(self)
1989             empty_other = not bool(other)
1990             if empty_self or empty_other:
1991                 return empty_self and empty_other  # True if both empty
1992             else:
1993                 return (
1994                     self._start == other._start
1995                     and self._step == other._step
1996                     and self._get_by_index(-1) == other._get_by_index(-1)
1997                 )
1998         else:
1999             return False
2000
2001     def __getitem__(self, key):
2002         if isinstance(key, int):
2003             return self._get_by_index(key)
2004         elif isinstance(key, slice):
2005             step = self._step if key.step is None else key.step * self._step
2006
2007             if key.start is None or key.start <= -self._len:
2008                 start = self._start
2009             elif key.start >= self._len:
2010                 start = self._stop
2011             else:  # -self._len < key.start < self._len
2012                 start = self._get_by_index(key.start)
2013
2014             if key.stop is None or key.stop >= self._len:
2015                 stop = self._stop
2016             elif key.stop <= -self._len:
2017                 stop = self._start
2018             else:  # -self._len < key.stop < self._len
2019                 stop = self._get_by_index(key.stop)
2020
2021             return numeric_range(start, stop, step)
2022         else:
2023             raise TypeError(
2024                 'numeric range indices must be '
2025                 'integers or slices, not {}'.format(type(key).__name__)
2026             )
2027
2028     def __hash__(self):
2029         if self:
2030             return hash((self._start, self._get_by_index(-1), self._step))
2031         else:
2032             return self._EMPTY_HASH
2033
2034     def __iter__(self):
2035         values = (self._start + (n * self._step) for n in count())
2036         if self._growing:
2037             return takewhile(partial(gt, self._stop), values)
2038         else:
2039             return takewhile(partial(lt, self._stop), values)
2040
2041     def __len__(self):
2042         return self._len
2043
2044     def _init_len(self):
2045         if self._growing:
2046             start = self._start
2047             stop = self._stop
2048             step = self._step
2049         else:
2050             start = self._stop
2051             stop = self._start
2052             step = -self._step
2053         distance = stop - start
2054         if distance <= self._zero:
2055             self._len = 0
2056         else:  # distance > 0 and step > 0: regular euclidean division
2057             q, r = divmod(distance, step)
2058             self._len = int(q) + int(r != self._zero)
2059
2060     def __reduce__(self):
2061         return numeric_range, (self._start, self._stop, self._step)
2062
2063     def __repr__(self):
2064         if self._step == 1:
2065             return "numeric_range({}, {})".format(
2066                 repr(self._start), repr(self._stop)
2067             )
2068         else:
2069             return "numeric_range({}, {}, {})".format(
2070                 repr(self._start), repr(self._stop), repr(self._step)
2071             )
2072
2073     def __reversed__(self):
2074         return iter(
2075             numeric_range(
2076                 self._get_by_index(-1), self._start - self._step, -self._step
2077             )
2078         )
2079
2080     def count(self, value):
2081         return int(value in self)
2082
2083     def index(self, value):
2084         if self._growing:
2085             if self._start <= value < self._stop:
2086                 q, r = divmod(value - self._start, self._step)
2087                 if r == self._zero:
2088                     return int(q)
2089         else:
2090             if self._start >= value > self._stop:
2091                 q, r = divmod(self._start - value, -self._step)
2092                 if r == self._zero:
2093                     return int(q)
2094
2095         raise ValueError("{} is not in numeric range".format(value))
2096
2097     def _get_by_index(self, i):
2098         if i < 0:
2099             i += self._len
2100         if i < 0 or i >= self._len:
2101             raise IndexError("numeric range object index out of range")
2102         return self._start + i * self._step
2103
2104
2105 def count_cycle(iterable, n=None):
2106     """Cycle through the items from *iterable* up to *n* times, yielding
2107     the number of completed cycles along with each item. If *n* is omitted the
2108     process repeats indefinitely.
2109
2110     >>> list(count_cycle('AB', 3))
2111     [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')]
2112
2113     """
2114     iterable = tuple(iterable)
2115     if not iterable:
2116         return iter(())
2117     counter = count() if n is None else range(n)
2118     return ((i, item) for i in counter for item in iterable)
2119
2120
2121 def mark_ends(iterable):
2122     """Yield 3-tuples of the form ``(is_first, is_last, item)``.
2123
2124     >>> list(mark_ends('ABC'))
2125     [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')]
2126
2127     Use this when looping over an iterable to take special action on its first
2128     and/or last items:
2129
2130     >>> iterable = ['Header', 100, 200, 'Footer']
2131     >>> total = 0
2132     >>> for is_first, is_last, item in mark_ends(iterable):
2133     ...     if is_first:
2134     ...         continue  # Skip the header
2135     ...     if is_last:
2136     ...         continue  # Skip the footer
2137     ...     total += item
2138     >>> print(total)
2139     300
2140     """
2141     it = iter(iterable)
2142
2143     try:
2144         b = next(it)
2145     except StopIteration:
2146         return
2147
2148     try:
2149         for i in count():
2150             a = b
2151             b = next(it)
2152             yield i == 0, False, a
2153
2154     except StopIteration:
2155         yield i == 0, True, a
2156
2157
2158 def locate(iterable, pred=bool, window_size=None):
2159     """Yield the index of each item in *iterable* for which *pred* returns
2160     ``True``.
2161
2162     *pred* defaults to :func:`bool`, which will select truthy items:
2163
2164         >>> list(locate([0, 1, 1, 0, 1, 0, 0]))
2165         [1, 2, 4]
2166
2167     Set *pred* to a custom function to, e.g., find the indexes for a particular
2168     item.
2169
2170         >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
2171         [1, 3]
2172
2173     If *window_size* is given, then the *pred* function will be called with
2174     that many items. This enables searching for sub-sequences:
2175
2176         >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
2177         >>> pred = lambda *args: args == (1, 2, 3)
2178         >>> list(locate(iterable, pred=pred, window_size=3))
2179         [1, 5, 9]
2180
2181     Use with :func:`seekable` to find indexes and then retrieve the associated
2182     items:
2183
2184         >>> from itertools import count
2185         >>> from more_itertools import seekable
2186         >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count())
2187         >>> it = seekable(source)
2188         >>> pred = lambda x: x > 100
2189         >>> indexes = locate(it, pred=pred)
2190         >>> i = next(indexes)
2191         >>> it.seek(i)
2192         >>> next(it)
2193         106
2194
2195     """
2196     if window_size is None:
2197         return compress(count(), map(pred, iterable))
2198
2199     if window_size < 1:
2200         raise ValueError('window size must be at least 1')
2201
2202     it = windowed(iterable, window_size, fillvalue=_marker)
2203     return compress(count(), starmap(pred, it))
2204
2205
2206 def lstrip(iterable, pred):
2207     """Yield the items from *iterable*, but strip any from the beginning
2208     for which *pred* returns ``True``.
2209
2210     For example, to remove a set of items from the start of an iterable:
2211
2212         >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2213         >>> pred = lambda x: x in {None, False, ''}
2214         >>> list(lstrip(iterable, pred))
2215         [1, 2, None, 3, False, None]
2216
2217     This function is analogous to to :func:`str.lstrip`, and is essentially
2218     an wrapper for :func:`itertools.dropwhile`.
2219
2220     """
2221     return dropwhile(pred, iterable)
2222
2223
2224 def rstrip(iterable, pred):
2225     """Yield the items from *iterable*, but strip any from the end
2226     for which *pred* returns ``True``.
2227
2228     For example, to remove a set of items from the end of an iterable:
2229
2230         >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2231         >>> pred = lambda x: x in {None, False, ''}
2232         >>> list(rstrip(iterable, pred))
2233         [None, False, None, 1, 2, None, 3]
2234
2235     This function is analogous to :func:`str.rstrip`.
2236
2237     """
2238     cache = []
2239     cache_append = cache.append
2240     cache_clear = cache.clear
2241     for x in iterable:
2242         if pred(x):
2243             cache_append(x)
2244         else:
2245             yield from cache
2246             cache_clear()
2247             yield x
2248
2249
2250 def strip(iterable, pred):
2251     """Yield the items from *iterable*, but strip any from the
2252     beginning and end for which *pred* returns ``True``.
2253
2254     For example, to remove a set of items from both ends of an iterable:
2255
2256         >>> iterable = (None, False, None, 1, 2, None, 3, False, None)
2257         >>> pred = lambda x: x in {None, False, ''}
2258         >>> list(strip(iterable, pred))
2259         [1, 2, None, 3]
2260
2261     This function is analogous to :func:`str.strip`.
2262
2263     """
2264     return rstrip(lstrip(iterable, pred), pred)
2265
2266
2267 class islice_extended:
2268     """An extension of :func:`itertools.islice` that supports negative values
2269     for *stop*, *start*, and *step*.
2270
2271         >>> iterable = iter('abcdefgh')
2272         >>> list(islice_extended(iterable, -4, -1))
2273         ['e', 'f', 'g']
2274
2275     Slices with negative values require some caching of *iterable*, but this
2276     function takes care to minimize the amount of memory required.
2277
2278     For example, you can use a negative step with an infinite iterator:
2279
2280         >>> from itertools import count
2281         >>> list(islice_extended(count(), 110, 99, -2))
2282         [110, 108, 106, 104, 102, 100]
2283
2284     You can also use slice notation directly:
2285
2286         >>> iterable = map(str, count())
2287         >>> it = islice_extended(iterable)[10:20:2]
2288         >>> list(it)
2289         ['10', '12', '14', '16', '18']
2290
2291     """
2292
2293     def __init__(self, iterable, *args):
2294         it = iter(iterable)
2295         if args:
2296             self._iterable = _islice_helper(it, slice(*args))
2297         else:
2298             self._iterable = it
2299
2300     def __iter__(self):
2301         return self
2302
2303     def __next__(self):
2304         return next(self._iterable)
2305
2306     def __getitem__(self, key):
2307         if isinstance(key, slice):
2308             return islice_extended(_islice_helper(self._iterable, key))
2309
2310         raise TypeError('islice_extended.__getitem__ argument must be a slice')
2311
2312
2313 def _islice_helper(it, s):
2314     start = s.start
2315     stop = s.stop
2316     if s.step == 0:
2317         raise ValueError('step argument must be a non-zero integer or None.')
2318     step = s.step or 1
2319
2320     if step > 0:
2321         start = 0 if (start is None) else start
2322
2323         if start < 0:
2324             # Consume all but the last -start items
2325             cache = deque(enumerate(it, 1), maxlen=-start)
2326             len_iter = cache[-1][0] if cache else 0
2327
2328             # Adjust start to be positive
2329             i = max(len_iter + start, 0)
2330
2331             # Adjust stop to be positive
2332             if stop is None:
2333                 j = len_iter
2334             elif stop >= 0:
2335                 j = min(stop, len_iter)
2336             else:
2337                 j = max(len_iter + stop, 0)
2338
2339             # Slice the cache
2340             n = j - i
2341             if n <= 0:
2342                 return
2343
2344             for index, item in islice(cache, 0, n, step):
2345                 yield item
2346         elif (stop is not None) and (stop < 0):
2347             # Advance to the start position
2348             next(islice(it, start, start), None)
2349
2350             # When stop is negative, we have to carry -stop items while
2351             # iterating
2352             cache = deque(islice(it, -stop), maxlen=-stop)
2353
2354             for index, item in enumerate(it):
2355                 cached_item = cache.popleft()
2356                 if index % step == 0:
2357                     yield cached_item
2358                 cache.append(item)
2359         else:
2360             # When both start and stop are positive we have the normal case
2361             yield from islice(it, start, stop, step)
2362     else:
2363         start = -1 if (start is None) else start
2364
2365         if (stop is not None) and (stop < 0):
2366             # Consume all but the last items
2367             n = -stop - 1
2368             cache = deque(enumerate(it, 1), maxlen=n)
2369             len_iter = cache[-1][0] if cache else 0
2370
2371             # If start and stop are both negative they are comparable and
2372             # we can just slice. Otherwise we can adjust start to be negative
2373             # and then slice.
2374             if start < 0:
2375                 i, j = start, stop
2376             else:
2377                 i, j = min(start - len_iter, -1), None
2378
2379             for index, item in list(cache)[i:j:step]:
2380                 yield item
2381         else:
2382             # Advance to the stop position
2383             if stop is not None:
2384                 m = stop + 1
2385                 next(islice(it, m, m), None)
2386
2387             # stop is positive, so if start is negative they are not comparable
2388             # and we need the rest of the items.
2389             if start < 0:
2390                 i = start
2391                 n = None
2392             # stop is None and start is positive, so we just need items up to
2393             # the start index.
2394             elif stop is None:
2395                 i = None
2396                 n = start + 1
2397             # Both stop and start are positive, so they are comparable.
2398             else:
2399                 i = None
2400                 n = start - stop
2401                 if n <= 0:
2402                     return
2403
2404             cache = list(islice(it, n))
2405
2406             yield from cache[i::step]
2407
2408
2409 def always_reversible(iterable):
2410     """An extension of :func:`reversed` that supports all iterables, not
2411     just those which implement the ``Reversible`` or ``Sequence`` protocols.
2412
2413         >>> print(*always_reversible(x for x in range(3)))
2414         2 1 0
2415
2416     If the iterable is already reversible, this function returns the
2417     result of :func:`reversed()`. If the iterable is not reversible,
2418     this function will cache the remaining items in the iterable and
2419     yield them in reverse order, which may require significant storage.
2420     """
2421     try:
2422         return reversed(iterable)
2423     except TypeError:
2424         return reversed(list(iterable))
2425
2426
2427 def consecutive_groups(iterable, ordering=lambda x: x):
2428     """Yield groups of consecutive items using :func:`itertools.groupby`.
2429     The *ordering* function determines whether two items are adjacent by
2430     returning their position.
2431
2432     By default, the ordering function is the identity function. This is
2433     suitable for finding runs of numbers:
2434
2435         >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40]
2436         >>> for group in consecutive_groups(iterable):
2437         ...     print(list(group))
2438         [1]
2439         [10, 11, 12]
2440         [20]
2441         [30, 31, 32, 33]
2442         [40]
2443
2444     For finding runs of adjacent letters, try using the :meth:`index` method
2445     of a string of letters:
2446
2447         >>> from string import ascii_lowercase
2448         >>> iterable = 'abcdfgilmnop'
2449         >>> ordering = ascii_lowercase.index
2450         >>> for group in consecutive_groups(iterable, ordering):
2451         ...     print(list(group))
2452         ['a', 'b', 'c', 'd']
2453         ['f', 'g']
2454         ['i']
2455         ['l', 'm', 'n', 'o', 'p']
2456
2457     Each group of consecutive items is an iterator that shares it source with
2458     *iterable*. When an an output group is advanced, the previous group is
2459     no longer available unless its elements are copied (e.g., into a ``list``).
2460
2461         >>> iterable = [1, 2, 11, 12, 21, 22]
2462         >>> saved_groups = []
2463         >>> for group in consecutive_groups(iterable):
2464         ...     saved_groups.append(list(group))  # Copy group elements
2465         >>> saved_groups
2466         [[1, 2], [11, 12], [21, 22]]
2467
2468     """
2469     for k, g in groupby(
2470         enumerate(iterable), key=lambda x: x[0] - ordering(x[1])
2471     ):
2472         yield map(itemgetter(1), g)
2473
2474
2475 def difference(iterable, func=sub, *, initial=None):
2476     """This function is the inverse of :func:`itertools.accumulate`. By default
2477     it will compute the first difference of *iterable* using
2478     :func:`operator.sub`:
2479
2480         >>> from itertools import accumulate
2481         >>> iterable = accumulate([0, 1, 2, 3, 4])  # produces 0, 1, 3, 6, 10
2482         >>> list(difference(iterable))
2483         [0, 1, 2, 3, 4]
2484
2485     *func* defaults to :func:`operator.sub`, but other functions can be
2486     specified. They will be applied as follows::
2487
2488         A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ...
2489
2490     For example, to do progressive division:
2491
2492         >>> iterable = [1, 2, 6, 24, 120]
2493         >>> func = lambda x, y: x // y
2494         >>> list(difference(iterable, func))
2495         [1, 2, 3, 4, 5]
2496
2497     If the *initial* keyword is set, the first element will be skipped when
2498     computing successive differences.
2499
2500         >>> it = [10, 11, 13, 16]  # from accumulate([1, 2, 3], initial=10)
2501         >>> list(difference(it, initial=10))
2502         [1, 2, 3]
2503
2504     """
2505     a, b = tee(iterable)
2506     try:
2507         first = [next(b)]
2508     except StopIteration:
2509         return iter([])
2510
2511     if initial is not None:
2512         first = []
2513
2514     return chain(first, starmap(func, zip(b, a)))
2515
2516
2517 class SequenceView(Sequence):
2518     """Return a read-only view of the sequence object *target*.
2519
2520     :class:`SequenceView` objects are analogous to Python's built-in
2521     "dictionary view" types. They provide a dynamic view of a sequence's items,
2522     meaning that when the sequence updates, so does the view.
2523
2524         >>> seq = ['0', '1', '2']
2525         >>> view = SequenceView(seq)
2526         >>> view
2527         SequenceView(['0', '1', '2'])
2528         >>> seq.append('3')
2529         >>> view
2530         SequenceView(['0', '1', '2', '3'])
2531
2532     Sequence views support indexing, slicing, and length queries. They act
2533     like the underlying sequence, except they don't allow assignment:
2534
2535         >>> view[1]
2536         '1'
2537         >>> view[1:-1]
2538         ['1', '2']
2539         >>> len(view)
2540         4
2541
2542     Sequence views are useful as an alternative to copying, as they don't
2543     require (much) extra storage.
2544
2545     """
2546
2547     def __init__(self, target):
2548         if not isinstance(target, Sequence):
2549             raise TypeError
2550         self._target = target
2551
2552     def __getitem__(self, index):
2553         return self._target[index]
2554
2555     def __len__(self):
2556         return len(self._target)
2557
2558     def __repr__(self):
2559         return '{}({})'.format(self.__class__.__name__, repr(self._target))
2560
2561
2562 class seekable:
2563     """Wrap an iterator to allow for seeking backward and forward. This
2564     progressively caches the items in the source iterable so they can be
2565     re-visited.
2566
2567     Call :meth:`seek` with an index to seek to that position in the source
2568     iterable.
2569
2570     To "reset" an iterator, seek to ``0``:
2571
2572         >>> from itertools import count
2573         >>> it = seekable((str(n) for n in count()))
2574         >>> next(it), next(it), next(it)
2575         ('0', '1', '2')
2576         >>> it.seek(0)
2577         >>> next(it), next(it), next(it)
2578         ('0', '1', '2')
2579         >>> next(it)
2580         '3'
2581
2582     You can also seek forward:
2583
2584         >>> it = seekable((str(n) for n in range(20)))
2585         >>> it.seek(10)
2586         >>> next(it)
2587         '10'
2588         >>> it.seek(20)  # Seeking past the end of the source isn't a problem
2589         >>> list(it)
2590         []
2591         >>> it.seek(0)  # Resetting works even after hitting the end
2592         >>> next(it), next(it), next(it)
2593         ('0', '1', '2')
2594
2595     Call :meth:`peek` to look ahead one item without advancing the iterator:
2596
2597         >>> it = seekable('1234')
2598         >>> it.peek()
2599         '1'
2600         >>> list(it)
2601         ['1', '2', '3', '4']
2602         >>> it.peek(default='empty')
2603         'empty'
2604
2605     Before the iterator is at its end, calling :func:`bool` on it will return
2606     ``True``. After it will return ``False``:
2607
2608         >>> it = seekable('5678')
2609         >>> bool(it)
2610         True
2611         >>> list(it)
2612         ['5', '6', '7', '8']
2613         >>> bool(it)
2614         False
2615
2616     You may view the contents of the cache with the :meth:`elements` method.
2617     That returns a :class:`SequenceView`, a view that updates automatically:
2618
2619         >>> it = seekable((str(n) for n in range(10)))
2620         >>> next(it), next(it), next(it)
2621         ('0', '1', '2')
2622         >>> elements = it.elements()
2623         >>> elements
2624         SequenceView(['0', '1', '2'])
2625         >>> next(it)
2626         '3'
2627         >>> elements
2628         SequenceView(['0', '1', '2', '3'])
2629
2630     By default, the cache grows as the source iterable progresses, so beware of
2631     wrapping very large or infinite iterables. Supply *maxlen* to limit the
2632     size of the cache (this of course limits how far back you can seek).
2633
2634         >>> from itertools import count
2635         >>> it = seekable((str(n) for n in count()), maxlen=2)
2636         >>> next(it), next(it), next(it), next(it)
2637         ('0', '1', '2', '3')
2638         >>> list(it.elements())
2639         ['2', '3']
2640         >>> it.seek(0)
2641         >>> next(it), next(it), next(it), next(it)
2642         ('2', '3', '4', '5')
2643         >>> next(it)
2644         '6'
2645
2646     """
2647
2648     def __init__(self, iterable, maxlen=None):
2649         self._source = iter(iterable)
2650         if maxlen is None:
2651             self._cache = []
2652         else:
2653             self._cache = deque([], maxlen)
2654         self._index = None
2655
2656     def __iter__(self):
2657         return self
2658
2659     def __next__(self):
2660         if self._index is not None:
2661             try:
2662                 item = self._cache[self._index]
2663             except IndexError:
2664                 self._index = None
2665             else:
2666                 self._index += 1
2667                 return item
2668
2669         item = next(self._source)
2670         self._cache.append(item)
2671         return item
2672
2673     def __bool__(self):
2674         try:
2675             self.peek()
2676         except StopIteration:
2677             return False
2678         return True
2679
2680     def peek(self, default=_marker):
2681         try:
2682             peeked = next(self)
2683         except StopIteration:
2684             if default is _marker:
2685                 raise
2686             return default
2687         if self._index is None:
2688             self._index = len(self._cache)
2689         self._index -= 1
2690         return peeked
2691
2692     def elements(self):
2693         return SequenceView(self._cache)
2694
2695     def seek(self, index):
2696         self._index = index
2697         remainder = index - len(self._cache)
2698         if remainder > 0:
2699             consume(self, remainder)
2700
2701
2702 class run_length:
2703     """
2704     :func:`run_length.encode` compresses an iterable with run-length encoding.
2705     It yields groups of repeated items with the count of how many times they
2706     were repeated:
2707
2708         >>> uncompressed = 'abbcccdddd'
2709         >>> list(run_length.encode(uncompressed))
2710         [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
2711
2712     :func:`run_length.decode` decompresses an iterable that was previously
2713     compressed with run-length encoding. It yields the items of the
2714     decompressed iterable:
2715
2716         >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
2717         >>> list(run_length.decode(compressed))
2718         ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd']
2719
2720     """
2721
2722     @staticmethod
2723     def encode(iterable):
2724         return ((k, ilen(g)) for k, g in groupby(iterable))
2725
2726     @staticmethod
2727     def decode(iterable):
2728         return chain.from_iterable(repeat(k, n) for k, n in iterable)
2729
2730
2731 def exactly_n(iterable, n, predicate=bool):
2732     """Return ``True`` if exactly ``n`` items in the iterable are ``True``
2733     according to the *predicate* function.
2734
2735         >>> exactly_n([True, True, False], 2)
2736         True
2737         >>> exactly_n([True, True, False], 1)
2738         False
2739         >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
2740         True
2741
2742     The iterable will be advanced until ``n + 1`` truthy items are encountered,
2743     so avoid calling it on infinite iterables.
2744
2745     """
2746     return len(take(n + 1, filter(predicate, iterable))) == n
2747
2748
2749 def circular_shifts(iterable):
2750     """Return a list of circular shifts of *iterable*.
2751
2752     >>> circular_shifts(range(4))
2753     [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
2754     """
2755     lst = list(iterable)
2756     return take(len(lst), windowed(cycle(lst), len(lst)))
2757
2758
2759 def make_decorator(wrapping_func, result_index=0):
2760     """Return a decorator version of *wrapping_func*, which is a function that
2761     modifies an iterable. *result_index* is the position in that function's
2762     signature where the iterable goes.
2763
2764     This lets you use itertools on the "production end," i.e. at function
2765     definition. This can augment what the function returns without changing the
2766     function's code.
2767
2768     For example, to produce a decorator version of :func:`chunked`:
2769
2770         >>> from more_itertools import chunked
2771         >>> chunker = make_decorator(chunked, result_index=0)
2772         >>> @chunker(3)
2773         ... def iter_range(n):
2774         ...     return iter(range(n))
2775         ...
2776         >>> list(iter_range(9))
2777         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
2778
2779     To only allow truthy items to be returned:
2780
2781         >>> truth_serum = make_decorator(filter, result_index=1)
2782         >>> @truth_serum(bool)
2783         ... def boolean_test():
2784         ...     return [0, 1, '', ' ', False, True]
2785         ...
2786         >>> list(boolean_test())
2787         [1, ' ', True]
2788
2789     The :func:`peekable` and :func:`seekable` wrappers make for practical
2790     decorators:
2791
2792         >>> from more_itertools import peekable
2793         >>> peekable_function = make_decorator(peekable)
2794         >>> @peekable_function()
2795         ... def str_range(*args):
2796         ...     return (str(x) for x in range(*args))
2797         ...
2798         >>> it = str_range(1, 20, 2)
2799         >>> next(it), next(it), next(it)
2800         ('1', '3', '5')
2801         >>> it.peek()
2802         '7'
2803         >>> next(it)
2804         '7'
2805
2806     """
2807     # See https://sites.google.com/site/bbayles/index/decorator_factory for
2808     # notes on how this works.
2809     def decorator(*wrapping_args, **wrapping_kwargs):
2810         def outer_wrapper(f):
2811             def inner_wrapper(*args, **kwargs):
2812                 result = f(*args, **kwargs)
2813                 wrapping_args_ = list(wrapping_args)
2814                 wrapping_args_.insert(result_index, result)
2815                 return wrapping_func(*wrapping_args_, **wrapping_kwargs)
2816
2817             return inner_wrapper
2818
2819         return outer_wrapper
2820
2821     return decorator
2822
2823
2824 def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
2825     """Return a dictionary that maps the items in *iterable* to categories
2826     defined by *keyfunc*, transforms them with *valuefunc*, and
2827     then summarizes them by category with *reducefunc*.
2828
2829     *valuefunc* defaults to the identity function if it is unspecified.
2830     If *reducefunc* is unspecified, no summarization takes place:
2831
2832         >>> keyfunc = lambda x: x.upper()
2833         >>> result = map_reduce('abbccc', keyfunc)
2834         >>> sorted(result.items())
2835         [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
2836
2837     Specifying *valuefunc* transforms the categorized items:
2838
2839         >>> keyfunc = lambda x: x.upper()
2840         >>> valuefunc = lambda x: 1
2841         >>> result = map_reduce('abbccc', keyfunc, valuefunc)
2842         >>> sorted(result.items())
2843         [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
2844
2845     Specifying *reducefunc* summarizes the categorized items:
2846
2847         >>> keyfunc = lambda x: x.upper()
2848         >>> valuefunc = lambda x: 1
2849         >>> reducefunc = sum
2850         >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc)
2851         >>> sorted(result.items())
2852         [('A', 1), ('B', 2), ('C', 3)]
2853
2854     You may want to filter the input iterable before applying the map/reduce
2855     procedure:
2856
2857         >>> all_items = range(30)
2858         >>> items = [x for x in all_items if 10 <= x <= 20]  # Filter
2859         >>> keyfunc = lambda x: x % 2  # Evens map to 0; odds to 1
2860         >>> categories = map_reduce(items, keyfunc=keyfunc)
2861         >>> sorted(categories.items())
2862         [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])]
2863         >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum)
2864         >>> sorted(summaries.items())
2865         [(0, 90), (1, 75)]
2866
2867     Note that all items in the iterable are gathered into a list before the
2868     summarization step, which may require significant storage.
2869
2870     The returned object is a :obj:`collections.defaultdict` with the
2871     ``default_factory`` set to ``None``, such that it behaves like a normal
2872     dictionary.
2873
2874     """
2875     valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc
2876
2877     ret = defaultdict(list)
2878     for item in iterable:
2879         key = keyfunc(item)
2880         value = valuefunc(item)
2881         ret[key].append(value)
2882
2883     if reducefunc is not None:
2884         for key, value_list in ret.items():
2885             ret[key] = reducefunc(value_list)
2886
2887     ret.default_factory = None
2888     return ret
2889
2890
2891 def rlocate(iterable, pred=bool, window_size=None):
2892     """Yield the index of each item in *iterable* for which *pred* returns
2893     ``True``, starting from the right and moving left.
2894
2895     *pred* defaults to :func:`bool`, which will select truthy items:
2896
2897         >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
2898         [4, 2, 1]
2899
2900     Set *pred* to a custom function to, e.g., find the indexes for a particular
2901     item:
2902
2903         >>> iterable = iter('abcb')
2904         >>> pred = lambda x: x == 'b'
2905         >>> list(rlocate(iterable, pred))
2906         [3, 1]
2907
2908     If *window_size* is given, then the *pred* function will be called with
2909     that many items. This enables searching for sub-sequences:
2910
2911         >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
2912         >>> pred = lambda *args: args == (1, 2, 3)
2913         >>> list(rlocate(iterable, pred=pred, window_size=3))
2914         [9, 5, 1]
2915
2916     Beware, this function won't return anything for infinite iterables.
2917     If *iterable* is reversible, ``rlocate`` will reverse it and search from
2918     the right. Otherwise, it will search from the left and return the results
2919     in reverse order.
2920
2921     See :func:`locate` to for other example applications.
2922
2923     """
2924     if window_size is None:
2925         try:
2926             len_iter = len(iterable)
2927             return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
2928         except TypeError:
2929             pass
2930
2931     return reversed(list(locate(iterable, pred, window_size)))
2932
2933
2934 def replace(iterable, pred, substitutes, count=None, window_size=1):
2935     """Yield the items from *iterable*, replacing the items for which *pred*
2936     returns ``True`` with the items from the iterable *substitutes*.
2937
2938         >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
2939         >>> pred = lambda x: x == 0
2940         >>> substitutes = (2, 3)
2941         >>> list(replace(iterable, pred, substitutes))
2942         [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
2943
2944     If *count* is given, the number of replacements will be limited:
2945
2946         >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
2947         >>> pred = lambda x: x == 0
2948         >>> substitutes = [None]
2949         >>> list(replace(iterable, pred, substitutes, count=2))
2950         [1, 1, None, 1, 1, None, 1, 1, 0]
2951
2952     Use *window_size* to control the number of items passed as arguments to
2953     *pred*. This allows for locating and replacing subsequences.
2954
2955         >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
2956         >>> window_size = 3
2957         >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
2958         >>> substitutes = [3, 4] # Splice in these items
2959         >>> list(replace(iterable, pred, substitutes, window_size=window_size))
2960         [3, 4, 5, 3, 4, 5]
2961
2962     """
2963     if window_size < 1:
2964         raise ValueError('window_size must be at least 1')
2965
2966     # Save the substitutes iterable, since it's used more than once
2967     substitutes = tuple(substitutes)
2968
2969     # Add padding such that the number of windows matches the length of the
2970     # iterable
2971     it = chain(iterable, [_marker] * (window_size - 1))
2972     windows = windowed(it, window_size)
2973
2974     n = 0
2975     for w in windows:
2976         # If the current window matches our predicate (and we haven't hit
2977         # our maximum number of replacements), splice in the substitutes
2978         # and then consume the following windows that overlap with this one.
2979         # For example, if the iterable is (0, 1, 2, 3, 4...)
2980         # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
2981         # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
2982         if pred(*w):
2983             if (count is None) or (n < count):
2984                 n += 1
2985                 yield from substitutes
2986                 consume(windows, window_size - 1)
2987                 continue
2988
2989         # If there was no match (or we've reached the replacement limit),
2990         # yield the first item from the window.
2991         if w and (w[0] is not _marker):
2992             yield w[0]
2993
2994
2995 def partitions(iterable):
2996     """Yield all possible order-preserving partitions of *iterable*.
2997
2998     >>> iterable = 'abc'
2999     >>> for part in partitions(iterable):
3000     ...     print([''.join(p) for p in part])
3001     ['abc']
3002     ['a', 'bc']
3003     ['ab', 'c']
3004     ['a', 'b', 'c']
3005
3006     This is unrelated to :func:`partition`.
3007
3008     """
3009     sequence = list(iterable)
3010     n = len(sequence)
3011     for i in powerset(range(1, n)):
3012         yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
3013
3014
3015 def set_partitions(iterable, k=None):
3016     """
3017     Yield the set partitions of *iterable* into *k* parts. Set partitions are
3018     not order-preserving.
3019
3020     >>> iterable = 'abc'
3021     >>> for part in set_partitions(iterable, 2):
3022     ...     print([''.join(p) for p in part])
3023     ['a', 'bc']
3024     ['ab', 'c']
3025     ['b', 'ac']
3026
3027
3028     If *k* is not given, every set partition is generated.
3029
3030     >>> iterable = 'abc'
3031     >>> for part in set_partitions(iterable):
3032     ...     print([''.join(p) for p in part])
3033     ['abc']
3034     ['a', 'bc']
3035     ['ab', 'c']
3036     ['b', 'ac']
3037     ['a', 'b', 'c']
3038
3039     """
3040     L = list(iterable)
3041     n = len(L)
3042     if k is not None:
3043         if k < 1:
3044             raise ValueError(
3045                 "Can't partition in a negative or zero number of groups"
3046             )
3047         elif k > n:
3048             return
3049
3050     def set_partitions_helper(L, k):
3051         n = len(L)
3052         if k == 1:
3053             yield [L]
3054         elif n == k:
3055             yield [[s] for s in L]
3056         else:
3057             e, *M = L
3058             for p in set_partitions_helper(M, k - 1):
3059                 yield [[e], *p]
3060             for p in set_partitions_helper(M, k):
3061                 for i in range(len(p)):
3062                     yield p[:i] + [[e] + p[i]] + p[i + 1 :]
3063
3064     if k is None:
3065         for k in range(1, n + 1):
3066             yield from set_partitions_helper(L, k)
3067     else:
3068         yield from set_partitions_helper(L, k)
3069
3070
3071 class time_limited:
3072     """
3073     Yield items from *iterable* until *limit_seconds* have passed.
3074     If the time limit expires before all items have been yielded, the
3075     ``timed_out`` parameter will be set to ``True``.
3076
3077     >>> from time import sleep
3078     >>> def generator():
3079     ...     yield 1
3080     ...     yield 2
3081     ...     sleep(0.2)
3082     ...     yield 3
3083     >>> iterable = time_limited(0.1, generator())
3084     >>> list(iterable)
3085     [1, 2]
3086     >>> iterable.timed_out
3087     True
3088
3089     Note that the time is checked before each item is yielded, and iteration
3090     stops if  the time elapsed is greater than *limit_seconds*. If your time
3091     limit is 1 second, but it takes 2 seconds to generate the first item from
3092     the iterable, the function will run for 2 seconds and not yield anything.
3093
3094     """
3095
3096     def __init__(self, limit_seconds, iterable):
3097         if limit_seconds < 0:
3098             raise ValueError('limit_seconds must be positive')
3099         self.limit_seconds = limit_seconds
3100         self._iterable = iter(iterable)
3101         self._start_time = monotonic()
3102         self.timed_out = False
3103
3104     def __iter__(self):
3105         return self
3106
3107     def __next__(self):
3108         item = next(self._iterable)
3109         if monotonic() - self._start_time > self.limit_seconds:
3110             self.timed_out = True
3111             raise StopIteration
3112
3113         return item
3114
3115
3116 def only(iterable, default=None, too_long=None):
3117     """If *iterable* has only one item, return it.
3118     If it has zero items, return *default*.
3119     If it has more than one item, raise the exception given by *too_long*,
3120     which is ``ValueError`` by default.
3121
3122     >>> only([], default='missing')
3123     'missing'
3124     >>> only([1])
3125     1
3126     >>> only([1, 2])  # doctest: +IGNORE_EXCEPTION_DETAIL
3127     Traceback (most recent call last):
3128     ...
3129     ValueError: Expected exactly one item in iterable, but got 1, 2,
3130      and perhaps more.'
3131     >>> only([1, 2], too_long=TypeError)  # doctest: +IGNORE_EXCEPTION_DETAIL
3132     Traceback (most recent call last):
3133     ...
3134     TypeError
3135
3136     Note that :func:`only` attempts to advance *iterable* twice to ensure there
3137     is only one item.  See :func:`spy` or :func:`peekable` to check
3138     iterable contents less destructively.
3139     """
3140     it = iter(iterable)
3141     first_value = next(it, default)
3142
3143     try:
3144         second_value = next(it)
3145     except StopIteration:
3146         pass
3147     else:
3148         msg = (
3149             'Expected exactly one item in iterable, but got {!r}, {!r}, '
3150             'and perhaps more.'.format(first_value, second_value)
3151         )
3152         raise too_long or ValueError(msg)
3153
3154     return first_value
3155
3156
3157 def ichunked(iterable, n):
3158     """Break *iterable* into sub-iterables with *n* elements each.
3159     :func:`ichunked` is like :func:`chunked`, but it yields iterables
3160     instead of lists.
3161
3162     If the sub-iterables are read in order, the elements of *iterable*
3163     won't be stored in memory.
3164     If they are read out of order, :func:`itertools.tee` is used to cache
3165     elements as necessary.
3166
3167     >>> from itertools import count
3168     >>> all_chunks = ichunked(count(), 4)
3169     >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
3170     >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
3171     [4, 5, 6, 7]
3172     >>> list(c_1)
3173     [0, 1, 2, 3]
3174     >>> list(c_3)
3175     [8, 9, 10, 11]
3176
3177     """
3178     source = iter(iterable)
3179
3180     while True:
3181         # Check to see whether we're at the end of the source iterable
3182         item = next(source, _marker)
3183         if item is _marker:
3184             return
3185
3186         # Clone the source and yield an n-length slice
3187         source, it = tee(chain([item], source))
3188         yield islice(it, n)
3189
3190         # Advance the source iterable
3191         consume(source, n)
3192
3193
3194 def distinct_combinations(iterable, r):
3195     """Yield the distinct combinations of *r* items taken from *iterable*.
3196
3197         >>> list(distinct_combinations([0, 0, 1], 2))
3198         [(0, 0), (0, 1)]
3199
3200     Equivalent to ``set(combinations(iterable))``, except duplicates are not
3201     generated and thrown away. For larger input sequences this is much more
3202     efficient.
3203
3204     """
3205     if r < 0:
3206         raise ValueError('r must be non-negative')
3207     elif r == 0:
3208         yield ()
3209         return
3210     pool = tuple(iterable)
3211     generators = [unique_everseen(enumerate(pool), key=itemgetter(1))]
3212     current_combo = [None] * r
3213     level = 0
3214     while generators:
3215         try:
3216             cur_idx, p = next(generators[-1])
3217         except StopIteration:
3218             generators.pop()
3219             level -= 1
3220             continue
3221         current_combo[level] = p
3222         if level + 1 == r:
3223             yield tuple(current_combo)
3224         else:
3225             generators.append(
3226                 unique_everseen(
3227                     enumerate(pool[cur_idx + 1 :], cur_idx + 1),
3228                     key=itemgetter(1),
3229                 )
3230             )
3231             level += 1
3232
3233
3234 def filter_except(validator, iterable, *exceptions):
3235     """Yield the items from *iterable* for which the *validator* function does
3236     not raise one of the specified *exceptions*.
3237
3238     *validator* is called for each item in *iterable*.
3239     It should be a function that accepts one argument and raises an exception
3240     if that item is not valid.
3241
3242     >>> iterable = ['1', '2', 'three', '4', None]
3243     >>> list(filter_except(int, iterable, ValueError, TypeError))
3244     ['1', '2', '4']
3245
3246     If an exception other than one given by *exceptions* is raised by
3247     *validator*, it is raised like normal.
3248     """
3249     for item in iterable:
3250         try:
3251             validator(item)
3252         except exceptions:
3253             pass
3254         else:
3255             yield item
3256
3257
3258 def map_except(function, iterable, *exceptions):
3259     """Transform each item from *iterable* with *function* and yield the
3260     result, unless *function* raises one of the specified *exceptions*.
3261
3262     *function* is called to transform each item in *iterable*.
3263     It should be a accept one argument.
3264
3265     >>> iterable = ['1', '2', 'three', '4', None]
3266     >>> list(map_except(int, iterable, ValueError, TypeError))
3267     [1, 2, 4]
3268
3269     If an exception other than one given by *exceptions* is raised by
3270     *function*, it is raised like normal.
3271     """
3272     for item in iterable:
3273         try:
3274             yield function(item)
3275         except exceptions:
3276             pass
3277
3278
3279 def _sample_unweighted(iterable, k):
3280     # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li:
3281     # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
3282
3283     # Fill up the reservoir (collection of samples) with the first `k` samples
3284     reservoir = take(k, iterable)
3285
3286     # Generate random number that's the largest in a sample of k U(0,1) numbers
3287     # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic
3288     W = exp(log(random()) / k)
3289
3290     # The number of elements to skip before changing the reservoir is a random
3291     # number with a geometric distribution. Sample it using random() and logs.
3292     next_index = k + floor(log(random()) / log(1 - W))
3293
3294     for index, element in enumerate(iterable, k):
3295
3296         if index == next_index:
3297             reservoir[randrange(k)] = element
3298             # The new W is the largest in a sample of k U(0, `old_W`) numbers
3299             W *= exp(log(random()) / k)
3300             next_index += floor(log(random()) / log(1 - W)) + 1
3301
3302     return reservoir
3303
3304
3305 def _sample_weighted(iterable, k, weights):
3306     # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
3307     # "Weighted random sampling with a reservoir".
3308
3309     # Log-transform for numerical stability for weights that are small/large
3310     weight_keys = (log(random()) / weight for weight in weights)
3311
3312     # Fill up the reservoir (collection of samples) with the first `k`
3313     # weight-keys and elements, then heapify the list.
3314     reservoir = take(k, zip(weight_keys, iterable))
3315     heapify(reservoir)
3316
3317     # The number of jumps before changing the reservoir is a random variable
3318     # with an exponential distribution. Sample it using random() and logs.
3319     smallest_weight_key, _ = reservoir[0]
3320     weights_to_skip = log(random()) / smallest_weight_key
3321
3322     for weight, element in zip(weights, iterable):
3323         if weight >= weights_to_skip:
3324             # The notation here is consistent with the paper, but we store
3325             # the weight-keys in log-space for better numerical stability.
3326             smallest_weight_key, _ = reservoir[0]
3327             t_w = exp(weight * smallest_weight_key)
3328             r_2 = uniform(t_w, 1)  # generate U(t_w, 1)
3329             weight_key = log(r_2) / weight
3330             heapreplace(reservoir, (weight_key, element))
3331             smallest_weight_key, _ = reservoir[0]
3332             weights_to_skip = log(random()) / smallest_weight_key
3333         else:
3334             weights_to_skip -= weight
3335
3336     # Equivalent to [element for weight_key, element in sorted(reservoir)]
3337     return [heappop(reservoir)[1] for _ in range(k)]
3338
3339
3340 def sample(iterable, k, weights=None):
3341     """Return a *k*-length list of elements chosen (without replacement)
3342     from the *iterable*. Like :func:`random.sample`, but works on iterables
3343     of unknown length.
3344
3345     >>> iterable = range(100)
3346     >>> sample(iterable, 5)  # doctest: +SKIP
3347     [81, 60, 96, 16, 4]
3348
3349     An iterable with *weights* may also be given:
3350
3351     >>> iterable = range(100)
3352     >>> weights = (i * i + 1 for i in range(100))
3353     >>> sampled = sample(iterable, 5, weights=weights)  # doctest: +SKIP
3354     [79, 67, 74, 66, 78]
3355
3356     The algorithm can also be used to generate weighted random permutations.
3357     The relative weight of each item determines the probability that it
3358     appears late in the permutation.
3359
3360     >>> data = "abcdefgh"
3361     >>> weights = range(1, len(data) + 1)
3362     >>> sample(data, k=len(data), weights=weights)  # doctest: +SKIP
3363     ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f']
3364     """
3365     if k == 0:
3366         return []
3367
3368     iterable = iter(iterable)
3369     if weights is None:
3370         return _sample_unweighted(iterable, k)
3371     else:
3372         weights = iter(weights)
3373         return _sample_weighted(iterable, k, weights)
3374
3375
3376 def is_sorted(iterable, key=None, reverse=False):
3377     """Returns ``True`` if the items of iterable are in sorted order, and
3378     ``False`` otherwise. *key* and *reverse* have the same meaning that they do
3379     in the built-in :func:`sorted` function.
3380
3381     >>> is_sorted(['1', '2', '3', '4', '5'], key=int)
3382     True
3383     >>> is_sorted([5, 4, 3, 1, 2], reverse=True)
3384     False
3385
3386     The function returns ``False`` after encountering the first out-of-order
3387     item. If there are no out-of-order items, the iterable is exhausted.
3388     """
3389
3390     compare = lt if reverse else gt
3391     it = iterable if (key is None) else map(key, iterable)
3392     return not any(starmap(compare, pairwise(it)))
3393
3394
3395 class AbortThread(BaseException):
3396     pass
3397
3398
3399 class callback_iter:
3400     """Convert a function that uses callbacks to an iterator.
3401
3402     Let *func* be a function that takes a `callback` keyword argument.
3403     For example:
3404
3405     >>> def func(callback=None):
3406     ...     for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]:
3407     ...         if callback:
3408     ...             callback(i, c)
3409     ...     return 4
3410
3411
3412     Use ``with callback_iter(func)`` to get an iterator over the parameters
3413     that are delivered to the callback.
3414
3415     >>> with callback_iter(func) as it:
3416     ...     for args, kwargs in it:
3417     ...         print(args)
3418     (1, 'a')
3419     (2, 'b')
3420     (3, 'c')
3421
3422     The function will be called in a background thread. The ``done`` property
3423     indicates whether it has completed execution.
3424
3425     >>> it.done
3426     True
3427
3428     If it completes successfully, its return value will be available
3429     in the ``result`` property.
3430
3431     >>> it.result
3432     4
3433
3434     Notes:
3435
3436     * If the function uses some keyword argument besides ``callback``, supply
3437       *callback_kwd*.
3438     * If it finished executing, but raised an exception, accessing the
3439       ``result`` property will raise the same exception.
3440     * If it hasn't finished executing, accessing the ``result``
3441       property from within the ``with`` block will raise ``RuntimeError``.
3442     * If it hasn't finished executing, accessing the ``result`` property from
3443       outside the ``with`` block will raise a
3444       ``more_itertools.AbortThread`` exception.
3445     * Provide *wait_seconds* to adjust how frequently the it is polled for
3446       output.
3447
3448     """
3449
3450     def __init__(self, func, callback_kwd='callback', wait_seconds=0.1):
3451         self._func = func
3452         self._callback_kwd = callback_kwd
3453         self._aborted = False
3454         self._future = None
3455         self._wait_seconds = wait_seconds
3456         self._executor = __import__("concurrent.futures").futures.ThreadPoolExecutor(max_workers=1)
3457         self._iterator = self._reader()
3458
3459     def __enter__(self):
3460         return self
3461
3462     def __exit__(self, exc_type, exc_value, traceback):
3463         self._aborted = True
3464         self._executor.shutdown()
3465
3466     def __iter__(self):
3467         return self
3468
3469     def __next__(self):
3470         return next(self._iterator)
3471
3472     @property
3473     def done(self):
3474         if self._future is None:
3475             return False
3476         return self._future.done()
3477
3478     @property
3479     def result(self):
3480         if not self.done:
3481             raise RuntimeError('Function has not yet completed')
3482
3483         return self._future.result()
3484
3485     def _reader(self):
3486         q = Queue()
3487
3488         def callback(*args, **kwargs):
3489             if self._aborted:
3490                 raise AbortThread('canceled by user')
3491
3492             q.put((args, kwargs))
3493
3494         self._future = self._executor.submit(
3495             self._func, **{self._callback_kwd: callback}
3496         )
3497
3498         while True:
3499             try:
3500                 item = q.get(timeout=self._wait_seconds)
3501             except Empty:
3502                 pass
3503             else:
3504                 q.task_done()
3505                 yield item
3506
3507             if self._future.done():
3508                 break
3509
3510         remaining = []
3511         while True:
3512             try:
3513                 item = q.get_nowait()
3514             except Empty:
3515                 break
3516             else:
3517                 q.task_done()
3518                 remaining.append(item)
3519         q.join()
3520         yield from remaining
3521
3522
3523 def windowed_complete(iterable, n):
3524     """
3525     Yield ``(beginning, middle, end)`` tuples, where:
3526
3527     * Each ``middle`` has *n* items from *iterable*
3528     * Each ``beginning`` has the items before the ones in ``middle``
3529     * Each ``end`` has the items after the ones in ``middle``
3530
3531     >>> iterable = range(7)
3532     >>> n = 3
3533     >>> for beginning, middle, end in windowed_complete(iterable, n):
3534     ...     print(beginning, middle, end)
3535     () (0, 1, 2) (3, 4, 5, 6)
3536     (0,) (1, 2, 3) (4, 5, 6)
3537     (0, 1) (2, 3, 4) (5, 6)
3538     (0, 1, 2) (3, 4, 5) (6,)
3539     (0, 1, 2, 3) (4, 5, 6) ()
3540
3541     Note that *n* must be at least 0 and most equal to the length of
3542     *iterable*.
3543
3544     This function will exhaust the iterable and may require significant
3545     storage.
3546     """
3547     if n < 0:
3548         raise ValueError('n must be >= 0')
3549
3550     seq = tuple(iterable)
3551     size = len(seq)
3552
3553     if n > size:
3554         raise ValueError('n must be <= len(seq)')
3555
3556     for i in range(size - n + 1):
3557         beginning = seq[:i]
3558         middle = seq[i : i + n]
3559         end = seq[i + n :]
3560         yield beginning, middle, end
3561
3562
3563 def all_unique(iterable, key=None):
3564     """
3565     Returns ``True`` if all the elements of *iterable* are unique (no two
3566     elements are equal).
3567
3568         >>> all_unique('ABCB')
3569         False
3570
3571     If a *key* function is specified, it will be used to make comparisons.
3572
3573         >>> all_unique('ABCb')
3574         True
3575         >>> all_unique('ABCb', str.lower)
3576         False
3577
3578     The function returns as soon as the first non-unique element is
3579     encountered. Iterables with a mix of hashable and unhashable items can
3580     be used, but the function will be slower for unhashable items.
3581     """
3582     seenset = set()
3583     seenset_add = seenset.add
3584     seenlist = []
3585     seenlist_add = seenlist.append
3586     for element in map(key, iterable) if key else iterable:
3587         try:
3588             if element in seenset:
3589                 return False
3590             seenset_add(element)
3591         except TypeError:
3592             if element in seenlist:
3593                 return False
3594             seenlist_add(element)
3595     return True
3596
3597
3598 def nth_product(index, *args):
3599     """Equivalent to ``list(product(*args))[index]``.
3600
3601     The products of *args* can be ordered lexicographically.
3602     :func:`nth_product` computes the product at sort position *index* without
3603     computing the previous products.
3604
3605         >>> nth_product(8, range(2), range(2), range(2), range(2))
3606         (1, 0, 0, 0)
3607
3608     ``IndexError`` will be raised if the given *index* is invalid.
3609     """
3610     pools = list(map(tuple, reversed(args)))
3611     ns = list(map(len, pools))
3612
3613     c = reduce(mul, ns)
3614
3615     if index < 0:
3616         index += c
3617
3618     if not 0 <= index < c:
3619         raise IndexError
3620
3621     result = []
3622     for pool, n in zip(pools, ns):
3623         result.append(pool[index % n])
3624         index //= n
3625
3626     return tuple(reversed(result))
3627
3628
3629 def nth_permutation(iterable, r, index):
3630     """Equivalent to ``list(permutations(iterable, r))[index]```
3631
3632     The subsequences of *iterable* that are of length *r* where order is
3633     important can be ordered lexicographically. :func:`nth_permutation`
3634     computes the subsequence at sort position *index* directly, without
3635     computing the previous subsequences.
3636
3637         >>> nth_permutation('ghijk', 2, 5)
3638         ('h', 'i')
3639
3640     ``ValueError`` will be raised If *r* is negative or greater than the length
3641     of *iterable*.
3642     ``IndexError`` will be raised if the given *index* is invalid.
3643     """
3644     pool = list(iterable)
3645     n = len(pool)
3646
3647     if r is None or r == n:
3648         r, c = n, factorial(n)
3649     elif not 0 <= r < n:
3650         raise ValueError
3651     else:
3652         c = factorial(n) // factorial(n - r)
3653
3654     if index < 0:
3655         index += c
3656
3657     if not 0 <= index < c:
3658         raise IndexError
3659
3660     if c == 0:
3661         return tuple()
3662
3663     result = [0] * r
3664     q = index * factorial(n) // c if r < n else index
3665     for d in range(1, n + 1):
3666         q, i = divmod(q, d)
3667         if 0 <= n - d < r:
3668             result[n - d] = i
3669         if q == 0:
3670             break
3671
3672     return tuple(map(pool.pop, result))
3673
3674
3675 def value_chain(*args):
3676     """Yield all arguments passed to the function in the same order in which
3677     they were passed. If an argument itself is iterable then iterate over its
3678     values.
3679
3680         >>> list(value_chain(1, 2, 3, [4, 5, 6]))
3681         [1, 2, 3, 4, 5, 6]
3682
3683     Binary and text strings are not considered iterable and are emitted
3684     as-is:
3685
3686         >>> list(value_chain('12', '34', ['56', '78']))
3687         ['12', '34', '56', '78']
3688
3689
3690     Multiple levels of nesting are not flattened.
3691
3692     """
3693     for value in args:
3694         if isinstance(value, (str, bytes)):
3695             yield value
3696             continue
3697         try:
3698             yield from value
3699         except TypeError:
3700             yield value
3701
3702
3703 def product_index(element, *args):
3704     """Equivalent to ``list(product(*args)).index(element)``
3705
3706     The products of *args* can be ordered lexicographically.
3707     :func:`product_index` computes the first index of *element* without
3708     computing the previous products.
3709
3710         >>> product_index([8, 2], range(10), range(5))
3711         42
3712
3713     ``ValueError`` will be raised if the given *element* isn't in the product
3714     of *args*.
3715     """
3716     index = 0
3717
3718     for x, pool in zip_longest(element, args, fillvalue=_marker):
3719         if x is _marker or pool is _marker:
3720             raise ValueError('element is not a product of args')
3721
3722         pool = tuple(pool)
3723         index = index * len(pool) + pool.index(x)
3724
3725     return index
3726
3727
3728 def combination_index(element, iterable):
3729     """Equivalent to ``list(combinations(iterable, r)).index(element)``
3730
3731     The subsequences of *iterable* that are of length *r* can be ordered
3732     lexicographically. :func:`combination_index` computes the index of the
3733     first *element*, without computing the previous combinations.
3734
3735         >>> combination_index('adf', 'abcdefg')
3736         10
3737
3738     ``ValueError`` will be raised if the given *element* isn't one of the
3739     combinations of *iterable*.
3740     """
3741     element = enumerate(element)
3742     k, y = next(element, (None, None))
3743     if k is None:
3744         return 0
3745
3746     indexes = []
3747     pool = enumerate(iterable)
3748     for n, x in pool:
3749         if x == y:
3750             indexes.append(n)
3751             tmp, y = next(element, (None, None))
3752             if tmp is None:
3753                 break
3754             else:
3755                 k = tmp
3756     else:
3757         raise ValueError('element is not a combination of iterable')
3758
3759     n, _ = last(pool, default=(n, None))
3760
3761     # Python versiosn below 3.8 don't have math.comb
3762     index = 1
3763     for i, j in enumerate(reversed(indexes), start=1):
3764         j = n - j
3765         if i <= j:
3766             index += factorial(j) // (factorial(i) * factorial(j - i))
3767
3768     return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index
3769
3770
3771 def permutation_index(element, iterable):
3772     """Equivalent to ``list(permutations(iterable, r)).index(element)```
3773
3774     The subsequences of *iterable* that are of length *r* where order is
3775     important can be ordered lexicographically. :func:`permutation_index`
3776     computes the index of the first *element* directly, without computing
3777     the previous permutations.
3778
3779         >>> permutation_index([1, 3, 2], range(5))
3780         19
3781
3782     ``ValueError`` will be raised if the given *element* isn't one of the
3783     permutations of *iterable*.
3784     """
3785     index = 0
3786     pool = list(iterable)
3787     for i, x in zip(range(len(pool), -1, -1), element):
3788         r = pool.index(x)
3789         index = index * i + r
3790         del pool[r]
3791
3792     return index
3793
3794
3795 class countable:
3796     """Wrap *iterable* and keep a count of how many items have been consumed.
3797
3798     The ``items_seen`` attribute starts at ``0`` and increments as the iterable
3799     is consumed:
3800
3801         >>> iterable = map(str, range(10))
3802         >>> it = countable(iterable)
3803         >>> it.items_seen
3804         0
3805         >>> next(it), next(it)
3806         ('0', '1')
3807         >>> list(it)
3808         ['2', '3', '4', '5', '6', '7', '8', '9']
3809         >>> it.items_seen
3810         10
3811     """
3812
3813     def __init__(self, iterable):
3814         self._it = iter(iterable)
3815         self.items_seen = 0
3816
3817     def __iter__(self):
3818         return self
3819
3820     def __next__(self):
3821         item = next(self._it)
3822         self.items_seen += 1
3823
3824         return item