Быстрый код, чтобы определить, имеют ли любые два подмножества столбцов одинаковую сумму
Для заданных n и m я итерирую по всем n по m частичным циркулянтным матрицам с записями, которые равны либо 0, либо 1. Я хочу выяснить, существует ли такая матрица, что нет двух подмножеств столбцов, дающих одинаковую сумму. Здесь, когда мы добавляем столбцы, мы просто делаем это поэлементно. Мой текущий код использует ограничение программирования через ortools. Однако это не так быстро, как хотелось бы. Для n = 7 и m = 12 это занимает более 3 минут, а для n = 10, m = 18, оно не заканчивается, даже если нужно рассмотреть только 2^18 = 262144 различных матриц. Вот мой код
from scipy.linalg import circulant
import numpy as np
import itertools
from ortools.constraint_solver import pywrapcp as cs
n = 7
m = 12
def isdetecting(matrix):
X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
X1 = X.tolist()
for row in matrix:
x = X[row].tolist()
solver.Add(solver.Sum(x) == 0)
db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
solver.NewSearch(db)
count = 0
while (solver.NextSolution() and count < 2):
solution = [x.Value() for x in X1]
count += 1
solver.EndSearch()
if (count < 2):
return True
values = [-1,0,1]
solver = cs.Solver("scip")
for row in itertools.product([0,1],repeat = m):
M = np.array(circulant(row)[0:n], dtype=bool)
if isdetecting(M):
print M.astype(int)
break
Можно ли решить эту проблему достаточно быстро, чтобы можно было решить n = 10, m = 18?
2 ответа
Одной из проблем является то, что вы объявляете переменную "решатель" глобально, и кажется, что это может сбить с толку or-tools, чтобы использовать его много раз. Если переместить его внутрь "isdetect", то проблема (7,12) решается гораздо быстрее, примерно за 7 секунд (по сравнению с 2:51 минутами для исходной модели). Я не проверял это для большей проблемы, все же.
Кроме того, было бы неплохо проверить различные надписи (вместо solver.INT_VAR_DEFAULT и solver.INT_VALUE_DEFAULT), хотя двоичное значение, как правило, не очень чувствительно к разным надписям. Смотрите код для другой маркировки.
def isdetecting(matrix):
solver = cs.Solver("scip") # <----
X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
X1 = X.tolist()
for row in matrix:
x = X[row].tolist()
solver.Add(solver.Sum(x) == 0)
# db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
db = solver.Phase(X1, solver.CHOOSE_FIRST_UNBOUND, solver.ASSIGN_CENTER_VALUE)
solver.NewSearch(db)
count = 0
while (solver.NextSolution() and count < 2):
solution = [x.Value() for x in X1]
count += 1
solver.EndSearch()
if (count < 2):
print "FOUND"
return True
Изменить: Вот ограничения для удаления всех 0 решений, как указано в комментариях. Что я знаю, это требует отдельного списка. Теперь это занимает немного больше времени (10,4 с против 7 с).
X1Abs = [solver.IntVar(values, 'X1Abs[%i]' % i) for i in range(X1_len)]
for i in range(X1_len):
solver.Add(X1Abs[i] == abs(X1[i]))
solver.Add(solver.Sum(X1Abs) > 0)
Примерно так я и имел в виду. Я бы оценил время работы параметров командной строки 10 18 на моей машине менее чем за 8 часов.
public class Search {
public static void main(String[] args) {
int n = Integer.parseInt(args[0]);
int m = Integer.parseInt(args[1]);
int row = search(n, m);
if (row >= 0) {
printRow(m, row);
}
}
private static int search(int n, int m) {
if (n < 0 || m < n || m >= 31 || powOverflows(m + 1, n)) {
throw new IllegalArgumentException();
}
long[] column = new long[m];
long[] sums = new long[1 << m];
int row = 1 << m;
while (row-- > 0) {
System.err.println(row);
for (int j = 0; j < m; j++) {
column[j] = 0;
for (int i = 0; i < n; i++) {
column[j] = (column[j] * (m + 1)) + ((row >> ((i + j) % m)) & 1);
}
}
for (int subset = 0; subset < (1 << m); subset++) {
long sum = 0;
for (int j = 0; j < m; j++) {
if (((subset >> j) & 1) == 1) {
sum += column[j];
}
}
sums[subset] = sum;
}
java.util.Arrays.sort(sums);
boolean duplicate = false;
for (int k = 1; k < (1 << m); k++) {
if (sums[k - 1] == sums[k]) {
duplicate = true;
break;
}
}
if (!duplicate) {
break;
}
}
return row;
}
private static boolean powOverflows(long b, int e) {
if (b <= 0 || e < 0) {
throw new IllegalArgumentException();
}
if (e == 0) {
return false;
}
long max = Long.MAX_VALUE;
while (e > 1) {
if (b > Integer.MAX_VALUE) {
return true;
}
if ((e & 1) == 1) {
max /= b;
}
b *= b;
e >>= 1;
}
return b > max;
}
private static void printRow(int m, int row) {
for (int j = 0; j < m; j++) {
System.out.print((row >> j) & 1);
}
System.out.println();
}
}