ValueError: установка элемента массива с последовательностью в Python
Во-первых, вот мой код:
"""Softmax."""
scores = [3.0, 1.0, 0.2]
import numpy as np
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
num = np.exp(x)
score_len = len(x)
y = np.array([0]*score_len)
sum_n = np.sum(num)
#print sum_n
for index in range(1,score_len):
y[index] = (num[index])/sum_n
return y
print(softmax(scores))
Ошибка появляется в строке:
y[index] = (num[index])/sum_n
Я запускаю код с:
# Plot softmax curves
import matplotlib.pyplot as plt
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
plt.plot(x, softmax(scores).T, linewidth=2)
plt.show()
Что именно здесь происходит не так?
4 ответа
Просто редактирование print
Заявление как "отладчик" выявляет происходящее:
import numpy as np
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
num = np.exp(x)
score_len = len(x)
y = np.array([0]*score_len)
sum_n = np.sum(num)
#print sum_n
for index in range(1,score_len):
print((num[index])/sum_n)
y[index] = (num[index])/sum_n
return y
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
softmax(scores).T
это печатает
[ 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504 0.00065504 0.00065504 0.00065504 0.00065504
0.00065504 0.00065504]
поэтому вы пытаетесь присвоить этот массив одному элементу другого массива. Что не разрешено!
Есть несколько способов сделать это так, чтобы это работало. Просто меняется
y = np.array([0]*score_len)
чтобы многомерный массив работал:
y = np.zeros(score.shape)
Это должно сработать, но я не уверен, что это то, что вы хотели.
РЕДАКТИРОВАТЬ:
Кажется, вы не хотели многомерного ввода, поэтому вам просто нужно изменить:
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
в
scores = np.hstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
проверьте форму этих массивов, напечатав scores.shape
действительно поможет вам найти такие ошибки самостоятельно. Первый укладывается вдоль первой оси (vstack), а hstack - по нулевой (что вам нужно)
Это плохой способ инициализации массива:
y = np.array([0]*score_len)
лучше сделать что-то вроде
y = np.zeros((n,m))
где n
а также m
2 измерения конечного продукта. Я предполагаю из вашего другого вопроса, что вы хотите y
быть 2d (ведь вы делаете .T
на это после).
Обратите внимание на форму scores
что вы передаете в функцию. И при повторении включайте :
, Это может быть необязательным, но вам нужно, чтобы размеры были в вашем уме:
y[index,:] = (num[index,:])/sum_n
В итоге - сфокусируйтесь на понимании того, как работать с многомерными массивами, как их создавать и как их индексировать, как работать с ними без итераций и как правильно итерировать при необходимости.
Несоответствия в построении массива могут вызвать такие проблемы, например
[[1,2,3,4], [2,3], [1],[1,2,3,4]]
это плохой примерный массив.
Это должно работать отлично и быстро
scores = [3.0, 1.0, 0.2]
import numpy as np
def softmax(x):
num = np.exp(x)
score_len = len(x)
y = np.zeros(score_len, object) # or => np.asarray([None]*score_len)
sum_n = np.sum(num)
for i in range(score_len):
y[i] = num[i] / sum_n
return y
print(softmax(scores))
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
printout = softmax(scores).T
print(printout)
Выход:
[0.8360188027814407 0.11314284146556011 0.050838355752999158]
[ array([ 3.26123038e-05, 3.60421698e-05, 3.98327578e-05,
4.40220056e-05, 4.86518403e-05, 5.37685990e-05,
5.94234919e-05, 6.56731151e-05, 7.25800169e-05,
8.02133239e-05, 8.86494329e-05, 9.79727751e-05,
1.08276662e-04, 1.19664218e-04, 1.32249413e-04,
1.46158206e-04, 1.61529798e-04, 1.78518035e-04,
1.97292941e-04, 2.18042421e-04, 2.40974142e-04,
2.66317614e-04, 2.94326482e-04, 3.25281069e-04,
3.59491177e-04, 3.97299194e-04, 4.39083515e-04,
4.85262332e-04, 5.36297817e-04, 5.92700751e-04,
6.55035633e-04, 7.23926331e-04, 8.00062328e-04,
8.84205618e-04, 9.77198335e-04, 1.07997118e-03,
1.19355274e-03, 1.31907978e-03, 1.45780861e-03,
1.61112768e-03, 1.78057146e-03, 1.96783579e-03,
2.17479489e-03, 2.40352006e-03, 2.65630048e-03,
2.93566604e-03, 3.24441273e-03, 3.58563059e-03,
3.96273465e-03, 4.37949910e-03, 4.84009504e-03,
5.34913227e-03, 5.91170543e-03, 6.53344491e-03,
7.22057331e-03, 7.97996764e-03, 8.81922816e-03,
9.74675448e-03, 1.07718296e-02, 1.19047128e-02,
1.31567424e-02, 1.45404491e-02, 1.60696814e-02,
1.77597446e-02, 1.96275532e-02, 2.16918010e-02,
2.39731477e-02, 2.64944256e-02, 2.92808687e-02,
3.23603645e-02, 3.57637337e-02, 3.95250385e-02,
4.36819230e-02, 4.82759910e-02, 5.33532213e-02,
5.89644285e-02, 6.51657716e-02, 7.20193157e-02,
7.95936532e-02, 8.79645908e-02])
array([ 0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504,
0.00065504, 0.00065504, 0.00065504, 0.00065504, 0.00065504])
array([ 0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433,
0.00029433, 0.00029433, 0.00029433, 0.00029433, 0.00029433])]