521abd7c2ca633f90a5ba13a8060c5c3d0c32205
[SubU] /
1 """Imported from the recipes section of the itertools documentation.
2
3 All functions taken from the recipes section of the itertools library docs
4 [1]_.
5 Some backward-compatible usability improvements have been made.
6
7 .. [1] http://docs.python.org/library/itertools.html#recipes
8
9 """
10 import warnings
11 from collections import deque
12 from itertools import (
13     chain,
14     combinations,
15     count,
16     cycle,
17     groupby,
18     islice,
19     repeat,
20     starmap,
21     tee,
22     zip_longest,
23 )
24 import operator
25 from random import randrange, sample, choice
26
27 __all__ = [
28     'all_equal',
29     'consume',
30     'convolve',
31     'dotproduct',
32     'first_true',
33     'flatten',
34     'grouper',
35     'iter_except',
36     'ncycles',
37     'nth',
38     'nth_combination',
39     'padnone',
40     'pad_none',
41     'pairwise',
42     'partition',
43     'powerset',
44     'prepend',
45     'quantify',
46     'random_combination_with_replacement',
47     'random_combination',
48     'random_permutation',
49     'random_product',
50     'repeatfunc',
51     'roundrobin',
52     'tabulate',
53     'tail',
54     'take',
55     'unique_everseen',
56     'unique_justseen',
57 ]
58
59
60 def take(n, iterable):
61     """Return first *n* items of the iterable as a list.
62
63         >>> take(3, range(10))
64         [0, 1, 2]
65
66     If there are fewer than *n* items in the iterable, all of them are
67     returned.
68
69         >>> take(10, range(3))
70         [0, 1, 2]
71
72     """
73     return list(islice(iterable, n))
74
75
76 def tabulate(function, start=0):
77     """Return an iterator over the results of ``func(start)``,
78     ``func(start + 1)``, ``func(start + 2)``...
79
80     *func* should be a function that accepts one integer argument.
81
82     If *start* is not specified it defaults to 0. It will be incremented each
83     time the iterator is advanced.
84
85         >>> square = lambda x: x ** 2
86         >>> iterator = tabulate(square, -3)
87         >>> take(4, iterator)
88         [9, 4, 1, 0]
89
90     """
91     return map(function, count(start))
92
93
94 def tail(n, iterable):
95     """Return an iterator over the last *n* items of *iterable*.
96
97     >>> t = tail(3, 'ABCDEFG')
98     >>> list(t)
99     ['E', 'F', 'G']
100
101     """
102     return iter(deque(iterable, maxlen=n))
103
104
105 def consume(iterator, n=None):
106     """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
107     entirely.
108
109     Efficiently exhausts an iterator without returning values. Defaults to
110     consuming the whole iterator, but an optional second argument may be
111     provided to limit consumption.
112
113         >>> i = (x for x in range(10))
114         >>> next(i)
115         0
116         >>> consume(i, 3)
117         >>> next(i)
118         4
119         >>> consume(i)
120         >>> next(i)
121         Traceback (most recent call last):
122           File "<stdin>", line 1, in <module>
123         StopIteration
124
125     If the iterator has fewer items remaining than the provided limit, the
126     whole iterator will be consumed.
127
128         >>> i = (x for x in range(3))
129         >>> consume(i, 5)
130         >>> next(i)
131         Traceback (most recent call last):
132           File "<stdin>", line 1, in <module>
133         StopIteration
134
135     """
136     # Use functions that consume iterators at C speed.
137     if n is None:
138         # feed the entire iterator into a zero-length deque
139         deque(iterator, maxlen=0)
140     else:
141         # advance to the empty slice starting at position n
142         next(islice(iterator, n, n), None)
143
144
145 def nth(iterable, n, default=None):
146     """Returns the nth item or a default value.
147
148     >>> l = range(10)
149     >>> nth(l, 3)
150     3
151     >>> nth(l, 20, "zebra")
152     'zebra'
153
154     """
155     return next(islice(iterable, n, None), default)
156
157
158 def all_equal(iterable):
159     """
160     Returns ``True`` if all the elements are equal to each other.
161
162         >>> all_equal('aaaa')
163         True
164         >>> all_equal('aaab')
165         False
166
167     """
168     g = groupby(iterable)
169     return next(g, True) and not next(g, False)
170
171
172 def quantify(iterable, pred=bool):
173     """Return the how many times the predicate is true.
174
175     >>> quantify([True, False, True])
176     2
177
178     """
179     return sum(map(pred, iterable))
180
181
182 def pad_none(iterable):
183     """Returns the sequence of elements and then returns ``None`` indefinitely.
184
185         >>> take(5, pad_none(range(3)))
186         [0, 1, 2, None, None]
187
188     Useful for emulating the behavior of the built-in :func:`map` function.
189
190     See also :func:`padded`.
191
192     """
193     return chain(iterable, repeat(None))
194
195
196 padnone = pad_none
197
198
199 def ncycles(iterable, n):
200     """Returns the sequence elements *n* times
201
202     >>> list(ncycles(["a", "b"], 3))
203     ['a', 'b', 'a', 'b', 'a', 'b']
204
205     """
206     return chain.from_iterable(repeat(tuple(iterable), n))
207
208
209 def dotproduct(vec1, vec2):
210     """Returns the dot product of the two iterables.
211
212     >>> dotproduct([10, 10], [20, 20])
213     400
214
215     """
216     return sum(map(operator.mul, vec1, vec2))
217
218
219 def flatten(listOfLists):
220     """Return an iterator flattening one level of nesting in a list of lists.
221
222         >>> list(flatten([[0, 1], [2, 3]]))
223         [0, 1, 2, 3]
224
225     See also :func:`collapse`, which can flatten multiple levels of nesting.
226
227     """
228     return chain.from_iterable(listOfLists)
229
230
231 def repeatfunc(func, times=None, *args):
232     """Call *func* with *args* repeatedly, returning an iterable over the
233     results.
234
235     If *times* is specified, the iterable will terminate after that many
236     repetitions:
237
238         >>> from operator import add
239         >>> times = 4
240         >>> args = 3, 5
241         >>> list(repeatfunc(add, times, *args))
242         [8, 8, 8, 8]
243
244     If *times* is ``None`` the iterable will not terminate:
245
246         >>> from random import randrange
247         >>> times = None
248         >>> args = 1, 11
249         >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
250         [2, 4, 8, 1, 8, 4]
251
252     """
253     if times is None:
254         return starmap(func, repeat(args))
255     return starmap(func, repeat(args, times))
256
257
258 def _pairwise(iterable):
259     """Returns an iterator of paired items, overlapping, from the original
260
261     >>> take(4, pairwise(count()))
262     [(0, 1), (1, 2), (2, 3), (3, 4)]
263
264     On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
265
266     """
267     a, b = tee(iterable)
268     next(b, None)
269     yield from zip(a, b)
270
271
272 try:
273     from itertools import pairwise as itertools_pairwise
274 except ImportError:
275     pairwise = _pairwise
276 else:
277
278     def pairwise(iterable):
279         yield from itertools_pairwise(iterable)
280
281     pairwise.__doc__ = _pairwise.__doc__
282
283
284 def grouper(iterable, n, fillvalue=None):
285     """Collect data into fixed-length chunks or blocks.
286
287     >>> list(grouper('ABCDEFG', 3, 'x'))
288     [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
289
290     """
291     if isinstance(iterable, int):
292         warnings.warn(
293             "grouper expects iterable as first parameter", DeprecationWarning
294         )
295         n, iterable = iterable, n
296     args = [iter(iterable)] * n
297     return zip_longest(fillvalue=fillvalue, *args)
298
299
300 def roundrobin(*iterables):
301     """Yields an item from each iterable, alternating between them.
302
303         >>> list(roundrobin('ABC', 'D', 'EF'))
304         ['A', 'D', 'E', 'B', 'F', 'C']
305
306     This function produces the same output as :func:`interleave_longest`, but
307     may perform better for some inputs (in particular when the number of
308     iterables is small).
309
310     """
311     # Recipe credited to George Sakkis
312     pending = len(iterables)
313     nexts = cycle(iter(it).__next__ for it in iterables)
314     while pending:
315         try:
316             for next in nexts:
317                 yield next()
318         except StopIteration:
319             pending -= 1
320             nexts = cycle(islice(nexts, pending))
321
322
323 def partition(pred, iterable):
324     """
325     Returns a 2-tuple of iterables derived from the input iterable.
326     The first yields the items that have ``pred(item) == False``.
327     The second yields the items that have ``pred(item) == True``.
328
329         >>> is_odd = lambda x: x % 2 != 0
330         >>> iterable = range(10)
331         >>> even_items, odd_items = partition(is_odd, iterable)
332         >>> list(even_items), list(odd_items)
333         ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
334
335     If *pred* is None, :func:`bool` is used.
336
337         >>> iterable = [0, 1, False, True, '', ' ']
338         >>> false_items, true_items = partition(None, iterable)
339         >>> list(false_items), list(true_items)
340         ([0, False, ''], [1, True, ' '])
341
342     """
343     if pred is None:
344         pred = bool
345
346     evaluations = ((pred(x), x) for x in iterable)
347     t1, t2 = tee(evaluations)
348     return (
349         (x for (cond, x) in t1 if not cond),
350         (x for (cond, x) in t2 if cond),
351     )
352
353
354 def powerset(iterable):
355     """Yields all possible subsets of the iterable.
356
357         >>> list(powerset([1, 2, 3]))
358         [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
359
360     :func:`powerset` will operate on iterables that aren't :class:`set`
361     instances, so repeated elements in the input will produce repeated elements
362     in the output. Use :func:`unique_everseen` on the input to avoid generating
363     duplicates:
364
365         >>> seq = [1, 1, 0]
366         >>> list(powerset(seq))
367         [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
368         >>> from more_itertools import unique_everseen
369         >>> list(powerset(unique_everseen(seq)))
370         [(), (1,), (0,), (1, 0)]
371
372     """
373     s = list(iterable)
374     return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
375
376
377 def unique_everseen(iterable, key=None):
378     """
379     Yield unique elements, preserving order.
380
381         >>> list(unique_everseen('AAAABBBCCDAABBB'))
382         ['A', 'B', 'C', 'D']
383         >>> list(unique_everseen('ABBCcAD', str.lower))
384         ['A', 'B', 'C', 'D']
385
386     Sequences with a mix of hashable and unhashable items can be used.
387     The function will be slower (i.e., `O(n^2)`) for unhashable items.
388
389     Remember that ``list`` objects are unhashable - you can use the *key*
390     parameter to transform the list to a tuple (which is hashable) to
391     avoid a slowdown.
392
393         >>> iterable = ([1, 2], [2, 3], [1, 2])
394         >>> list(unique_everseen(iterable))  # Slow
395         [[1, 2], [2, 3]]
396         >>> list(unique_everseen(iterable, key=tuple))  # Faster
397         [[1, 2], [2, 3]]
398
399     Similary, you may want to convert unhashable ``set`` objects with
400     ``key=frozenset``. For ``dict`` objects,
401     ``key=lambda x: frozenset(x.items())`` can be used.
402
403     """
404     seenset = set()
405     seenset_add = seenset.add
406     seenlist = []
407     seenlist_add = seenlist.append
408     use_key = key is not None
409
410     for element in iterable:
411         k = key(element) if use_key else element
412         try:
413             if k not in seenset:
414                 seenset_add(k)
415                 yield element
416         except TypeError:
417             if k not in seenlist:
418                 seenlist_add(k)
419                 yield element
420
421
422 def unique_justseen(iterable, key=None):
423     """Yields elements in order, ignoring serial duplicates
424
425     >>> list(unique_justseen('AAAABBBCCDAABBB'))
426     ['A', 'B', 'C', 'D', 'A', 'B']
427     >>> list(unique_justseen('ABBCcAD', str.lower))
428     ['A', 'B', 'C', 'A', 'D']
429
430     """
431     return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
432
433
434 def iter_except(func, exception, first=None):
435     """Yields results from a function repeatedly until an exception is raised.
436
437     Converts a call-until-exception interface to an iterator interface.
438     Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
439     to end the loop.
440
441         >>> l = [0, 1, 2]
442         >>> list(iter_except(l.pop, IndexError))
443         [2, 1, 0]
444
445     """
446     try:
447         if first is not None:
448             yield first()
449         while 1:
450             yield func()
451     except exception:
452         pass
453
454
455 def first_true(iterable, default=None, pred=None):
456     """
457     Returns the first true value in the iterable.
458
459     If no true value is found, returns *default*
460
461     If *pred* is not None, returns the first item for which
462     ``pred(item) == True`` .
463
464         >>> first_true(range(10))
465         1
466         >>> first_true(range(10), pred=lambda x: x > 5)
467         6
468         >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
469         'missing'
470
471     """
472     return next(filter(pred, iterable), default)
473
474
475 def random_product(*args, repeat=1):
476     """Draw an item at random from each of the input iterables.
477
478         >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
479         ('c', 3, 'Z')
480
481     If *repeat* is provided as a keyword argument, that many items will be
482     drawn from each iterable.
483
484         >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
485         ('a', 2, 'd', 3)
486
487     This equivalent to taking a random selection from
488     ``itertools.product(*args, **kwarg)``.
489
490     """
491     pools = [tuple(pool) for pool in args] * repeat
492     return tuple(choice(pool) for pool in pools)
493
494
495 def random_permutation(iterable, r=None):
496     """Return a random *r* length permutation of the elements in *iterable*.
497
498     If *r* is not specified or is ``None``, then *r* defaults to the length of
499     *iterable*.
500
501         >>> random_permutation(range(5))  # doctest:+SKIP
502         (3, 4, 0, 1, 2)
503
504     This equivalent to taking a random selection from
505     ``itertools.permutations(iterable, r)``.
506
507     """
508     pool = tuple(iterable)
509     r = len(pool) if r is None else r
510     return tuple(sample(pool, r))
511
512
513 def random_combination(iterable, r):
514     """Return a random *r* length subsequence of the elements in *iterable*.
515
516         >>> random_combination(range(5), 3)  # doctest:+SKIP
517         (2, 3, 4)
518
519     This equivalent to taking a random selection from
520     ``itertools.combinations(iterable, r)``.
521
522     """
523     pool = tuple(iterable)
524     n = len(pool)
525     indices = sorted(sample(range(n), r))
526     return tuple(pool[i] for i in indices)
527
528
529 def random_combination_with_replacement(iterable, r):
530     """Return a random *r* length subsequence of elements in *iterable*,
531     allowing individual elements to be repeated.
532
533         >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
534         (0, 0, 1, 2, 2)
535
536     This equivalent to taking a random selection from
537     ``itertools.combinations_with_replacement(iterable, r)``.
538
539     """
540     pool = tuple(iterable)
541     n = len(pool)
542     indices = sorted(randrange(n) for i in range(r))
543     return tuple(pool[i] for i in indices)
544
545
546 def nth_combination(iterable, r, index):
547     """Equivalent to ``list(combinations(iterable, r))[index]``.
548
549     The subsequences of *iterable* that are of length *r* can be ordered
550     lexicographically. :func:`nth_combination` computes the subsequence at
551     sort position *index* directly, without computing the previous
552     subsequences.
553
554         >>> nth_combination(range(5), 3, 5)
555         (0, 3, 4)
556
557     ``ValueError`` will be raised If *r* is negative or greater than the length
558     of *iterable*.
559     ``IndexError`` will be raised if the given *index* is invalid.
560     """
561     pool = tuple(iterable)
562     n = len(pool)
563     if (r < 0) or (r > n):
564         raise ValueError
565
566     c = 1
567     k = min(r, n - r)
568     for i in range(1, k + 1):
569         c = c * (n - k + i) // i
570
571     if index < 0:
572         index += c
573
574     if (index < 0) or (index >= c):
575         raise IndexError
576
577     result = []
578     while r:
579         c, n, r = c * r // n, n - 1, r - 1
580         while index >= c:
581             index -= c
582             c, n = c * (n - r) // n, n - 1
583         result.append(pool[-1 - n])
584
585     return tuple(result)
586
587
588 def prepend(value, iterator):
589     """Yield *value*, followed by the elements in *iterator*.
590
591         >>> value = '0'
592         >>> iterator = ['1', '2', '3']
593         >>> list(prepend(value, iterator))
594         ['0', '1', '2', '3']
595
596     To prepend multiple values, see :func:`itertools.chain`
597     or :func:`value_chain`.
598
599     """
600     return chain([value], iterator)
601
602
603 def convolve(signal, kernel):
604     """Convolve the iterable *signal* with the iterable *kernel*.
605
606         >>> signal = (1, 2, 3, 4, 5)
607         >>> kernel = [3, 2, 1]
608         >>> list(convolve(signal, kernel))
609         [3, 8, 14, 20, 26, 14, 5]
610
611     Note: the input arguments are not interchangeable, as the *kernel*
612     is immediately consumed and stored.
613
614     """
615     kernel = tuple(kernel)[::-1]
616     n = len(kernel)
617     window = deque([0], maxlen=n) * n
618     for x in chain(signal, repeat(0, n - 1)):
619         window.append(x)
620         yield sum(map(operator.mul, kernel, window))