Is Flax much slower than pure JAX for neural networks?
For a project I am trying to write a very simple MLP example, but I noticed that the Flax implementation is about 20 times slower than the pure JAX implementation. What am I doing wrong here?
import time
import jax.numpy as np
from jax import random, jit, vmap, jacfwd
from jax.nn import sigmoid, softplus
import jax
from flax import linen as nn
import numpy as np
from typing import Sequence
def MLP(layers):
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W = random.normal(k1, (d_in, d_out))
            b = random.normal(k2, (d_out,))
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params

    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = sigmoid(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs

    return init, apply
class FlaxNet(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x_in):
        x = nn.Dense(self.features[0], use_bias=False)(x_in)
        x = sigmoid(x)
        for feat in self.features[1:-1]:
            x = nn.Dense(feat, use_bias=False)(x)
            x = sigmoid(x)
        x = nn.Dense(self.features[-1], use_bias=False)(x)
        return x
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
D = np.pi
layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))
inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
_ = net_apply(params, inputs)
inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
t1 = time.time()
outputs = net_apply(params, inputs)
print('TIME JAX ', time.time()-t1)
#############################################################################
model = FlaxNet(features=[64, 64, 64, 32, 4])
params = model.init(rng, inputs)
_ = model.apply(params, inputs)
t1 = time.time()
outputs = model.apply(params, inputs)
print('TIME FLAX ', time.time()-t1)
This produces the output:
TIME JAX 0.0033071041107177734
TIME FLAX 0.08791708946228027
1 Answer
You just need to drop the extra lines :) In your script, `import numpy as np` comes after `import jax.numpy as np` and rebinds `np` to classic NumPy, so your "pure JAX" `apply` was really doing NumPy `np.dot` calls, which are very fast on arrays this small. The Flax model, meanwhile, was running actual (un-jitted) JAX and paying per-op dispatch overhead on every call.
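A minimal, standalone sketch of the shadowing (not part of your script):

import jax.numpy as np
import numpy as np  # rebinds `np`: everything below is classic NumPy

x = np.ones((128, 64))
print(type(np.dot(x, x.T)))  # <class 'numpy.ndarray'>, not a JAX array

Here is the cleaned-up script, with the shadowing import (and the unused ones) removed, using the usual `jnp` alias so the two libraries cannot collide: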
import time
import jax.numpy as jnp
from jax import random
from jax.nn import sigmoid
import jax
from flax import linen as nn
from typing import Sequence
def MLP(layers):
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W = random.normal(k1, (d_in, d_out))
            b = random.normal(k2, (d_out,))
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params

    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = jnp.dot(inputs, W) + b
            inputs = sigmoid(outputs)
        W, b = params[-1]
        outputs = jnp.dot(inputs, W) + b
        return outputs

    return init, apply
class FlaxNet(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x_in):
        x = nn.Dense(self.features[0], use_bias=False)(x_in)
        x = sigmoid(x)
        for feat in self.features[1:-1]:
            x = nn.Dense(feat, use_bias=False)(x)
            x = sigmoid(x)
        x = nn.Dense(self.features[-1], use_bias=False)(x)
        return x
D = jnp.pi
layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))
inputs = jax.random.uniform(random.PRNGKey(1), minval=-D, maxval=D, shape=(128, 1))
t1 = time.time()
outputs = net_apply(params, inputs)
print('TIME JAX ', time.time() - t1)
model = FlaxNet(features=[64, 64, 64, 32, 4])
params = model.init(random.PRNGKey(0), inputs)
t1 = time.time()
_ = model.apply(params, inputs)
print('TIME FLAX ', time.time() - t1)
The new timings:
TIME JAX 0.854097843170166
TIME FLAX 0.04825115203857422
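One caveat about one-shot timings like these: JAX dispatches asynchronously, so `time.time()` can return before the work is actually done, and an un-warmed first call also pays compilation cost (which likely dominates the 0.85 s above). A fairer comparison jits both models, warms them up, and blocks on the result. A sketch, assuming the script above has already run; `bench`, `params_jax`, and `params_flax` are names I am introducing here for the two parameter sets:

import time
import jax

def bench(fn, *args, repeats=100):
    fn(*args).block_until_ready()   # warm-up call: excludes compilation
    t0 = time.time()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()         # wait for asynchronous execution to finish
    return (time.time() - t0) / repeats

jit_jax = jax.jit(net_apply)
jit_flax = jax.jit(model.apply)
print('JAX  per call:', bench(jit_jax, params_jax, inputs))
print('Flax per call:', bench(jit_flax, params_flax, inputs))

Measured this way, the two models should come out close. They are not byte-identical programs (the Flax version here uses `use_bias=False`), but neither should be 20x slower than the other.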