Esta é uma dificuldade conceitual comum quando se está aprendendo a usar NumPy efetivamente. Normalmente, o processamento de dados em Python é melhor expresso em termos de iteradores , para manter baixo o uso de memória, maximizar oportunidades de paralelismo com o sistema de E / S e fornecer reutilização e combinação de partes de algoritmos .
Mas o NumPy transforma tudo isso de dentro para fora: a melhor abordagem é expressar o algoritmo como uma sequência de operações de matriz inteira , para minimizar a quantidade de tempo gasto no interpretador lento do Python e maximizar o quantidade de tempo gasto em rotinas NumPy compiladas rapidamente.
Aqui está a abordagem geral que eu tomo:
-
Mantenha a versão original da função (que você está certo de que está correta) para que você possa testá-la em relação às suas versões aprimoradas, tanto para correção quanto para velocidade.
-
Trabalhe de dentro para fora: isto é, comece com o loop mais interno e veja se pode ser vetorizado; depois, quando tiver feito isso, saia de um nível e continue.
-
Gastar muito tempo lendo a documentação do NumPy . Há muitas funções e operações lá e elas nem sempre são nomeadas de forma brilhante, então vale a pena conhecê-las. Em particular, se você se acha pensando, "se houvesse uma função que fez isso e aquilo", vale a pena gastar dez minutos procurando por ela. Geralmente está lá em algum lugar.
Não há substituto para a prática, então vou dar alguns exemplos de problemas. O objetivo de cada problema é reescrever a função para que seja totalmente vetorizado : isto é, para que seja uma seqüência de operações NumPy em matrizes inteiras, sem loops nativos de Python (no for
ou while
declarações, sem iteradores ou compreensões).
Problema 1
def sumproducts(x, y):
"""Return the sum of x[i] * y[j] for all pairs of indices i, j.
>>> sumproducts(np.arange(3000), np.arange(3000))
20236502250000
"""
result = 0
for i in range(len(x)):
for j in range(len(y)):
result += x[i] * y[j]
return result
Problema 2
def countlower(x, y):
"""Return the number of pairs i, j such that x[i] < y[j].
>>> countlower(np.arange(0, 200, 2), np.arange(40, 140))
4500
"""
result = 0
for i in range(len(x)):
for j in range(len(y)):
if x[i] < y[j]:
result += 1
return result
Problema 3
def cleanup(x, missing=-1, value=0):
"""Return an array that's the same as x, except that where x ==
missing, it has value instead.
>>> cleanup(np.arange(-3, 3), value=10)
... # doctest: +NORMALIZE_WHITESPACE
array([-3, -2, 10, 0, 1, 2])
"""
result = []
for i in range(len(x)):
if x[i] == missing:
result.append(value)
else:
result.append(x[i])
return np.array(result)
Spoilers abaixo. Você obterá os melhores resultados se tiver sucesso antes de olhar para as minhas soluções!
Resposta 1
np.sum(x) * np.sum(y)
Resposta 2
np.sum(np.searchsorted(np.sort(x), y))
Resposta 3
np.where(x == missing, value, x)