Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
5 ответов

Не удалось установить jaxlib

Я пытаюсь установить jaxlib на свои окна 10 с помощью следующей команды, которую я нашел в документации. pip install jaxlib Он показывает следующую ошибку Collecting jaxlib Could not find a version that satisfies the requirement jaxlib (from version…
26 июн '20 в 01:44
1 ответ

Может ли pytorch оптимизировать последовательные операции (например, граф тензорного потока или JAX jit)?

Изначально tenorflow и pytorch имели принципиальное отличие: tenorflow основан на вычислительном графе. Построение этого графика и его оценка в сеансе - два отдельных шага. В процессе использования график не меняется, что позволяет проводить оптимиз…
28 окт '19 в 21:18
1 ответ

jax vmap: обеспечить правильную форму

Я использую vmap для векторизации частей моего кода. Вот минимальный пример до векторизации: dim = 2 def sum(x): a = np.ones((dim,)) return np.dot(x, a) num_samples = 100 samples = np.ones((num_samples, dim)) sum(samples[0]) # 2 с vmap: sum = vmap(s…
28 окт '19 в 15:55
1 ответ

JAX: время jit функции становится суперлинейной с доступом к памяти функцией

Вот простой пример, который численно интегрирует произведение двух гауссовских PDF-файлов. Один из гауссианов фиксирован, среднее всегда равно 0. Среднее значение другого гауссиана изменяется: import time import jax.numpy as np from jax import jit f…
27 ноя '19 в 14:00
0 ответов

Аналог tf.depthwise_conv2d с использованием Jax jax.lax.conv

Я переношу код с Tensorflow на Jax и сталкиваюсь со следующими трудностями: У меня есть два массива: R и S. У нас есть: R.shape (10,201,11) а также S.shape (61,11) Мне нужно свернуть каждый S[:,i] с соответствующим R[j,:,i] для всех j из 0:9, в резу…
10 мар '20 в 01:27
1 ответ

vmap над списком в jax

Используя jax, я пытаюсь вычислить градиенты для каждого образца, обработать их, а затем привести их в нормальную форму, чтобы вычислить нормальное обновление параметров. Мой рабочий код выглядит так differentiate_per_sample = jit(vmap(grad(loss), i…
14 май '20 в 02:49
0 ответов

Как использовать JAX и autorgrad для обратного распространения ошибки, созданного с помощью сокращения?

Некоторое время назад я построил numpyоснованная на машинном обучении "библиотека" как школьное домашнее задание. Он был основан исключительно наnumpy, но теперь я хочу перевести его на JAX. У меня возникли проблемы с настройкой процесса обратного р…
19 авг '20 в 03:15
1 ответ

Якобианский определитель векторной функции с Python JAX/Autograd

У меня есть функция, которая отображает векторы на векторы и я хочу вычислить его определитель Якоби , где якобиан определяется как . Поскольку я могу использовать numpy.linalg.det , чтобы вычислить определитель, мне просто нужна матрица Якоби. Я зн…
14 янв '20 в 18:35
2 ответа

Проблемы с ограничениями JIT и Numpy Jax

Недавно я начал экспериментировать с интересной библиотекой Python Jax, которая содержит улучшенный Numpy, а также автоматический дифференциатор. Я хотел попытаться создать грубый "дифференцируемый рендерер", написав шейдер и функцию потерь на pytho…
06 мар '20 в 04:16
1 ответ

Найти градиент функции: Sympy vs. Jax

У меня есть функция Black_Cox() который вызывает другие функции, как показано ниже: import numpy as np from scipyimport stats # Parameters D = 100 r = 0.05 γ = 0.1 # Normal CDF N = lambda x: stats.norm.cdf(x) H = lambda V, T, L, σ: np.exp(-r*T) * N(…
20 апр '20 в 17:06
0 ответов

Функционализация дополнительных вычислений с прогнозированием

Я работаю с Jax и Stax. Базовый цикл прямой связи для сети stax выглядит примерно так: def apply_fun(params, inputs, **kwargs): for fun, param, rng in zip(apply_funs, params): inputs = fun(param, inputs, **kwargs) return inputs Если вы не знакомы с …
1 ответ

Получить и опубликовать вызов API в java с базовой аутентификацией

Я хочу позвонить GET а также POST API в java без использования каких-либо framework. Мне нужно использовать базовую аутентификацию. Может ли кто-нибудь помочь мне с какой-нибудь учебной ссылкой. В Google я нашел код только вspring framework, но я не…
25 июн '20 в 00:40
2 ответа

Какой лучший способ вычислить точечные произведения по строкам (или по осям) с помощью jax?

У меня есть два числовых массива формы (N, M). Я хотел бы вычислить точечный продукт по строкам. Т.е. создайте массив формы (N,), такой, что n-я строка является скалярным произведением n-й строки из каждого массива. Я знаю о numpy inner1dметод. Как …
20 апр '20 в 05:54
0 ответов

Эффективный способ вычисления якобиана x якобиана.T

Предполагать J - якобиан некоторой функции fпо некоторым параметрам. Есть ли эффективные способы (в PyTorch или, возможно, Jax) иметь функцию, которая принимает два входа (x1 а также x2) и вычисляет J(x1)*J(x2).transpose() без создания всегоJ матриц…
24 авг '20 в 13:21
1 ответ

Сверточная нейронная сеть Google JAX 1D

Я пытаюсь реализовать 1D сверточную нейронную сеть в Google Jax с помощью stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html). У меня есть одномерный входной массив с 18 и выходной массив с 6 записями. Я хоч…
13 июн '20 в 11:17
0 ответов

MNIST DC-GAN Все градиенты нули

Я пытаюсь создать DC-GAN для MNIST на платформе Flax, используя пример TF в качестве ориентира. Сама сеть технически работает, но ни генератор, ни дискриминатор не обновляются, поскольку их градиенты всегда равны нулю. Я уже убедился, что веса иници…
1 ответ

Каков самый быстрый способ выбора подмножества матрицы JAX?

Допустим, у меня есть 2D-матрица, и я хочу отобразить ее значения в виде гистограммы. Для этого мне нужно сделать что-то вроде: list_1d = matrix_2d.reshape((-1,)).tolist() А затем используйте список для построения гистограммы. Пока все хорошо, прост…
31 окт '20 в 00:49
0 ответов

Работа с Google JAX на сервере gunicorn/flask

Я хочу обслуживать приложение, которое обрабатывает данные во фреймворке Googles JAX с помощью flask и gunicorn. Если запустить внутри колбы, все работает нормально. Как только я запускаю приложение в gunicorn, каждая часть, связанная с jax, приводи…
02 ноя '20 в 15:40
1 ответ

Нетранзитивное создание подклассов с помощью numpy и jax

Мой вопрос простой: >>> isinstance(x, jax.numpy.ndarray) True >>> issubclass(jax.numpy.ndarray, numpy.ndarray) True >>> isinstance(x, numpy.ndarray) False ? А теперь я побеспокоюсь, чтобы SE примет мой разумный вопрос.
03 ноя '20 в 02:26
1 ответ

Сравнение двух подходов возведения в степень элементов матрицы

У меня есть два подхода к возведению в степень матрицы в jnp = jax.numpy. Простой: jnp.exp(-X/reg) И с некоторыми дополнительными действиями: def exp_reg(X, reg): K = jnp.empty_like(X) K = jnp.divide(X, -reg) return jnp.exp(K) Однако когда я их прот…
04 ноя '20 в 21:00