I followed the instructions at
https://developer.apple.com/metal/jax/
and got
Successfully installed importlib-metadata-7.1.0 jax-0.4.28 jax-metal-0.0.7 jaxlib-0.4.28 opt-einsum-3.3.0 scipy-1.13.0 six-1.16.0 zipp-3.18.2
but the test failed:
python -c 'import jax; print(jax.numpy.arange(10))'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/__init__.py", line 37, in <module>
import jax.core as _core
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/core.py", line 18, in <module>
from jax._src.core import (
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/core.py", line 39, in <module>
from jax._src import dtypes
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/dtypes.py", line 33, in <module>
from jax._src import config
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/Users/erivas/jax-metal/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 84, in <module>
cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.
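A hedged diagnostic: jaxlib's AVX check comes from x86_64 builds, so this error suggests pip installed x86_64 wheels, i.e. the interpreter itself is x86_64. On an Apple Silicon Mac, one common way that happens (an assumption about this particular setup) is a Python running under Rosetta 2, which, at least on the macOS releases of that period, does not emulate AVX. The sketch below reports the interpreter's architecture and asks macOS whether the process is translated, via the macOS-specific `sysctl.proc_translated` key; the helper names are hypothetical.

```python
import platform
import subprocess


def running_under_rosetta() -> bool:
    """True if this process is an x86_64 binary translated by Rosetta 2.

    Reads the macOS sysctl key `sysctl.proc_translated` (1 = translated).
    Where the key or the sysctl tool is missing, returns False.
    """
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True, text=True, check=False,
        )
    except (FileNotFoundError, PermissionError):  # no usable sysctl binary
        return False
    return result.stdout.strip() == "1"


def interpreter_report() -> str:
    """One-line summary of the running interpreter's OS and CPU architecture."""
    machine = platform.machine()  # 'arm64' for a native Apple Silicon Python
    note = " (translated by Rosetta 2)" if running_under_rosetta() else ""
    return f"Python {platform.python_version()} on {platform.system()} {machine}{note}"


print(interpreter_report())
```

If this prints `x86_64` on an Apple Silicon Mac, recreating the virtual environment from a native arm64 Python would be a reasonable next step before reinstalling jax-metal.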
Cannot assign a device for operation encoder/down1/downs_0/conv1/weight/Initializer/random_uniform/RandomUniform: Could not satisfy explicit device specification '' because the node {{colocation_node encoder/down1/downs_0/conv1/weight/Initializer/random_uniform/RandomUniform}} was colocated with a group of nodes that required incompatible device '/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0].
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=-1 requested_device_name_='/device:GPU:0' assigned_device_name_='' resource_device_name_='/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]
Identity: GPU CPU
Mul: GPU CPU
AddV2: GPU CPU
Sub: GPU CPU
RandomUniform: GPU CPU
Assign: CPU
VariableV2: GPU CPU
Const: GPU CPU
Regardless of the installed version combination of tensorflow and tensorflow-metal (2.14, 2.15, 2.16), I find a Metal/non-Metal incompatibility for some layer types. For the GRU layer, for example, Metal-trained weights (model.save_weights()/load_weights()) are not compatible with inference on the CPU. That is: train a model using Metal, run inference using Metal, then run inference again after uninstalling tensorflow-metal, and the results differ, sometimes drastically. This essentially eliminates the usefulness of tensorflow-metal for me. In my limited testing, models using other simple combinations of layer types, including Dense and LSTM, do not show this problem; only GRU does. And by "testing" I mean really simple models, like a single GRU layer. Apple Metal framework team: you are doing very useful work, and I kindly ask you to please address this bug :)
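To make the Metal-vs-CPU difference easy to quantify in a bug report, one option is to save the predictions from each run to disk and compare them numerically. This is a standard-library sketch; the file names and the idea of flattening `model.predict(x)` into a list of floats are illustrative assumptions, not part of the original post.

```python
import json
from pathlib import Path
from typing import Iterable


def save_predictions(path: str, values: Iterable[float]) -> None:
    """Write a flat sequence of floats (e.g. model.predict(x).ravel()) as JSON."""
    Path(path).write_text(json.dumps([float(v) for v in values]))


def max_abs_diff(path_a: str, path_b: str) -> float:
    """Largest element-wise absolute difference between two saved prediction files."""
    a = json.loads(Path(path_a).read_text())
    b = json.loads(Path(path_b).read_text())
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


# Hypothetical workflow:
#   with tensorflow-metal installed:  save_predictions("preds_metal.json", model.predict(x).ravel())
#   after uninstalling it (CPU only): save_predictions("preds_cpu.json", model.predict(x).ravel())
#   then: print(max_abs_diff("preds_metal.json", "preds_cpu.json"))
```

Running the same comparison for a Dense or LSTM model alongside the GRU model would give a baseline for how much float noise to expect.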
I noticed from the system requirements that TensorFlow (with the Metal plugin) only seems to support Python. Are there any plans to add JavaScript support, given that TensorFlow has a JS version?
Thank you for your time...
I'm using a MacBook Pro with an M2 Pro chip. I was trying to work with TensorFlow but encountered an "illegal hardware instruction" error. To resolve it, I started installing the Metal plugin, which throws the following error:
or semicolon (after version specifier)
awscli>=1.16.100boto3>=1.9.100
~~~~~~~~~~~^
Unable to locate awscli
[end of output]
When fitting a CNN model, every second epoch takes zero seconds and emits OUT_OF_RANGE warnings. I'm using structured folders of categorical images for training and validation. Here is the warning message that appears after every second epoch.
The fitting looks like this...
37/37 ━━━━━━━━━━━━━━━━━━━━ 14s 337ms/step - accuracy: 0.5255 - loss: 1.0819 - val_accuracy: 0.2578 - val_loss: 2.4472
Epoch 4/20
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - accuracy: 0.5312 - loss: 1.1106 - val_accuracy: 0.1250 - val_loss: 3.0711
Epoch 5/20
2024-04-19 09:22:51.673909: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
[[{{node IteratorGetNext}}]]
2024-04-19 09:22:51.673928: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
[[{{node IteratorGetNext}}]]
[[IteratorGetNext/_59]]
2024-04-19 09:22:51.673940: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 10431687783238222105
2024-04-19 09:22:51.673944: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17360824274615977051
2024-04-19 09:22:51.673955: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 10732905483452597729
My setup is:
Tensor Flow Version: 2.16.1
Python 3.9.19 (main, Mar 21 2024, 12:07:41)
[Clang 14.0.6 ]
Pandas 2.2.2
Scikit-Learn 1.4.2
GPU is available
My generators are:
# datagen is an ImageDataGenerator defined earlier (not shown in the post)
train_generator = datagen.flow_from_directory(
    scalp_dir_train,            # directory
    target_size=(256, 256),     # all images found will be resized
    batch_size=32,
    class_mode='categorical',
    # subset='training'         # Specify the subset as training
)
n_samples = train_generator.samples  # gets the number of samples
validation_generator = datagen.flow_from_directory(
    scalp_dir_test,             # directory path
    target_size=(256, 256),
    batch_size=32,
    class_mode='categorical',
    # subset='validation'       # Specify the subset as validation
)
Here is my model.
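from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense,
                                     Dropout, Flatten, MaxPooling2D)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
early_stopping_monitor = EarlyStopping(patience=10, restore_best_weights=True)
optimizer = Adam(learning_rate=0.01)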
model = Sequential()
model.add(Conv2D(128, (3, 3), activation='relu',padding='same', input_shape=(256, 256, 3)))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.3))
model.add(Conv2D(64, (3, 3),padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(Dense(4, activation='softmax')) # Defined by the number of classes
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
Here is the fit...
history = model.fit(
    train_generator,
    steps_per_epoch=37,
    epochs=20,
    validation_data=validation_generator,
    validation_steps=12,
    callbacks=[early_stopping_monitor],
    # verbose=2
)
Hi,
I just noticed that jax.numpy.insert() returns an incorrect result (zero-padding the array) when compiled with jax.jit. Without jit, the results are correct.
Config:
M1 Pro Macbook Pro 2021
python 3.12.3 ; jax-metal 0.0.6 ; jax 0.4.26 ; jaxlib 0.4.23
MWE:
import jax
import jax.numpy as jnp
x = jnp.arange(20).reshape(5, 4)
print(f"{x=}\n")
def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)
x2 = return_arr_with_ins(x, 99)
print(f"{x2=}\n")
return_arr_with_ins_jit = jax.jit(return_arr_with_ins)
x3 = return_arr_with_ins_jit(x, 99)
print(f"{x3=}\n")
Output: x2 (computed with the non-jitted function) is correct; x3 just has zero-padding instead of a column of 99
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 2, 3, 0],
[ 4, 5, 6, 7, 0],
[ 8, 9, 10, 11, 0],
[12, 13, 14, 15, 0],
[16, 17, 18, 19, 0]], dtype=int32)
The same code run on a non-metal machine gives the correct results:
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
Not sure if this is the correct channel for bug reports, please feel free to let me know if there's a more appropriate place!
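Until the jitted jnp.insert is fixed on Metal, one possible workaround is to build the same result from slicing and jnp.concatenate, which lower to different operations. This is a sketch: the helper name `insert_column` is hypothetical, it only covers inserting a scalar column along axis 1 at a static index, and it has not been verified to avoid the Metal bug.

```python
import jax
import jax.numpy as jnp


def insert_column(arr, idx, value):
    """Insert a constant column before column `idx` of a 2-D array.

    Intended as a stand-in for jnp.insert(arr, idx, value, axis=1);
    `idx` must be a static Python int so the slices have fixed shapes.
    """
    col = jnp.full((arr.shape[0], 1), value, dtype=arr.dtype)
    return jnp.concatenate([arr[:, :idx], col, arr[:, idx:]], axis=1)


# Mark idx (positional argument 1) as static so it is a Python int under jit.
insert_column_jit = jax.jit(insert_column, static_argnums=1)

x = jnp.arange(20).reshape(5, 4)
print(insert_column_jit(x, 2, 99))
```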
Hello,
We are all facing issues with the latest TensorFlow GPU support: incorrect results, errors, etc. We all agreed to pay extra for the M1/M2/M3 so we could work on a professional-grade computer, but in the end we have to use the CPU. When will Apple actually comment on this and provide updates? I understand these issues aren't fixed overnight and take some time, but I've never seen an Apple developer reply saying that they understand and are working on a fix.
I basically bought a Mac M3 Pro to be able to run some workloads on the GPU without having to purchase a server, and now it's useless. It's really frustrating.
(Copied from https://github.com/google/jax/issues/20835)
I am attempting to use JAX on Metal (on an M1 Pro chip) to model discrete (count) data. I've installed the latest version, jax-metal 0.0.6, using pip.
The installation seems to have worked overall, as I can perform basic JAX array operations on the GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables that are defined in terms of the (log-)Gamma function, I get errors like the one below, which seem to indicate that the lax.lgamma function is not supported under the hood on M1 Metal.
This is essential functionality for a wide class of probabilistic machine learning models. Note that the following functions (among others) are broken as a result:
jax.scipy.stats.binom.logpmf
jax.scipy.stats.nbinom.logpmf
jax.scipy.stats.poisson.logpmf
jax.scipy.stats.dirichlet.logpdf
jax.scipy.stats.beta.logpdf
jax.scipy.stats.gamma.logpdf
...
>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
gammaln(n + 1),
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
return lax.lgamma(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
return lgamma_p.bind(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
<stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
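As a stop-gap, log-Gamma can be computed from plain elementwise operations instead of lax.lgamma. The sketch below swaps in the Lanczos approximation (g = 7 with the commonly published 9 coefficients), plus the recurrence log Γ(x) = log Γ(x + 1) - log x to handle 0 < x < 0.5. It assumes x > 0 and float32 inputs, and it is an assumption (not verified here) that jax-metal legalizes every op it uses. Because jax.scipy.stats still calls lax.lgamma internally, the log-PMF you need has to be rewritten on top of the helper, as in the hypothetical `poisson_logpmf` below.

```python
import math

import jax
import jax.numpy as jnp

# Lanczos coefficients for g = 7, n = 9 (the widely published set).
_G = 7.0
_COEF = (
    0.99999999999980993, 676.5203681218851, -1259.1392167224028,
    771.32342877765313, -176.61502916214059, 12.507343278686905,
    -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
)


def _lanczos_lgamma(z):
    """log Gamma(z) via the Lanczos approximation; accurate for z >= 0.5."""
    z = z - 1.0
    a = _COEF[0]
    for i in range(1, len(_COEF)):
        a = a + _COEF[i] / (z + i)
    t = z + _G + 0.5
    return 0.5 * math.log(2.0 * math.pi) + (z + 0.5) * jnp.log(t) - t + jnp.log(a)


def lgamma_pos(x):
    """log Gamma(x) for x > 0, using only add/divide/log/where."""
    x = jnp.asarray(x, dtype=jnp.float32)
    small = x < 0.5
    # Gamma(x) = Gamma(x + 1) / x, so shift small arguments into the valid range.
    shifted = jnp.where(small, x + 1.0, x)
    return _lanczos_lgamma(shifted) - jnp.where(small, jnp.log(x), 0.0)


def poisson_logpmf(k, mu):
    """Poisson log-PMF, k*log(mu) - mu - log(k!), built on lgamma_pos."""
    k = jnp.asarray(k, dtype=jnp.float32)
    return k * jnp.log(mu) - mu - lgamma_pos(k + 1.0)


print(poisson_logpmf(1, 2.0), jax.jit(lgamma_pos)(2.5))
```

In float32 this agrees with `math.lgamma` to roughly single-precision accuracy over moderate arguments; for large-scale use, comparing against `jax.scipy.special.gammaln` on the CPU would be a sensible check.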
Copying from https://github.com/google/jax/issues/20750:
import jax
import jax.numpy as jnp
def test_func(x, y):
    return x, y
def main():
    # Print available JAX devices
    print("JAX devices:", jax.devices())
    # Create two small matrices
    a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
    # Perform matrix multiplication
    c = jnp.dot(a, b)
    # Print the result
    print("Result of matrix multiplication:")
    print(c)
    # Compute the gradient of sum of c with respect to a
    grad_a = jax.grad(lambda a: jnp.sum(jnp.dot(a, b)))(a)
    print("Gradient with respect to a:")
    print(grad_a)
    rng = jax.random.PRNGKey(0)
    test_input = jax.random.normal(key=rng, shape=(5, 5, 5))
    initial_state = jax.numpy.array(0.0)
    x, y = jax.lax.scan(test_func, initial_state, test_input)
if __name__ == "__main__":
    main()
Gets:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-15 18:22:28.994752: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2 Pro
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
JAX devices: [METAL(id=0)]
Result of matrix multiplication:
[[19. 22.]
[43. 50.]]
Gradient with respect to a:
[[11. 15.]
[11. 15.]]
zsh: segmentation fault python JAXTest.py
With more info from the debugger:
Current thread 0x00000001fdd3bac0 (most recent call first):
File "/Users/.../anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1213 in __call__
My configuration is:
jax-metal : 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.24.3
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
macOS 14.4.1 (23E224)
This did not happen with an earlier setup (Python 3.9 + jax-metal 0.0.3, etc.).
Hi,
I encountered a segfault when calling a function through jax.lax.scan.
A minimum failing example is pasted below:
$ ipython
Python 3.9.6 (default, Feb 3 2024, 15:58:27)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.18.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax
In [2]: jax.__version__
Out[2]: '0.4.22'
In [3]: import jaxlib
In [4]: jaxlib.__version__
Out[4]: '0.4.22'
In [6]: import jax.numpy as jnp
In [7]: def f(carry, x):
   ...:     return carry + x * x, x * x
   ...:
   ...: jax.lax.scan(f, jnp.zeros((), dtype=jnp.float32), jnp.arange(3, dtype=jnp.float32))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-16 01:03:52.483015: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 36.00 GB
maxCacheSize: 13.50 GB
zsh: segmentation fault ipython
This might be related to the thread below:
https://developer.apple.com/forums/thread/749080
Strangely, when we call it
jax.lax.scan is a very important building block, so I would greatly appreciate if this can be resolved soon.
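If only the Metal path crashes, a temporary alternative is to replace lax.scan with an explicit Python loop that follows the same carry/output contract; the lax.scan documentation describes its semantics with essentially this loop. This sketch handles a single array `xs`, and the name `python_scan` is hypothetical. Under jax.jit the loop is unrolled at trace time, so compile time grows with sequence length; treat it as a debugging aid rather than a drop-in replacement for long sequences.

```python
import jax
import jax.numpy as jnp


def python_scan(f, init, xs):
    """Loop equivalent of jax.lax.scan(f, init, xs) for a single array xs.

    Iterates over the leading axis, threads the carry through f, and
    stacks the per-step outputs.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


def f(carry, x):
    # Running sum of squares (carry) and each square (per-step output).
    return carry + x * x, x * x


carry, ys = python_scan(f, jnp.zeros((), dtype=jnp.float32), jnp.arange(3, dtype=jnp.float32))
# carry == 5.0, ys == [0., 1., 4.]

# The helper also traces under jit (the loop is unrolled during tracing).
carry_jit, ys_jit = jax.jit(
    lambda xs: python_scan(f, jnp.zeros((), jnp.float32), xs)
)(jnp.arange(3, dtype=jnp.float32))
```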
Hardware: 16" 2023 MBP M3 Pro
OS: 14.4.1
Memory: 36 GB
python version: 3.8.16
TF-Metal version: tensorflow-metal 1.0.1 installed via pip
TF version: 2.13.0
With tensorflow-metal, inference starts out slow, at approximately 10 s/iteration, and over the course of 36 iterations it progressively slows down to over 120 s/iteration. The info log reports that TFLite is using XNNPack. I can't share the TFLite model, but it is relatively shallow, small, and simple.
After uninstalling tensorflow-metal and installing plain tensorflow, inference speed picks right up and is rock solid at 0.78 s/iteration. What is going on?
**TLDR, TFLite inference speed:
TF Metal = 120s/iteration
TF = 0.78s/iteration**
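To quantify the slowdown in a form that is easy to share, one could time each iteration separately and compare the start of the run with the end. This is a standard-library sketch; `step` stands for whatever one iteration does (for a TFLite model, presumably a call to `interpreter.invoke()` after setting the input tensors), and the 5-iteration window is an arbitrary choice.

```python
import statistics
import time
from typing import Callable, List


def time_iterations(step: Callable[[], object], n: int) -> List[float]:
    """Call `step` n times and return each call's wall-clock duration in seconds."""
    durations = []
    for _ in range(n):
        start = time.perf_counter()
        step()
        durations.append(time.perf_counter() - start)
    return durations


def drift_ratio(durations: List[float], window: int = 5) -> float:
    """Median of the last `window` timings divided by the median of the first."""
    if len(durations) < window:
        raise ValueError("need at least `window` timings")
    head = statistics.median(durations[:window])
    tail = statistics.median(durations[-window:])
    if head == 0:
        raise ValueError("first timings are zero; cannot compute a ratio")
    return tail / head
```

A ratio near 1 means steady throughput; the numbers in this post (about 10 s rising to over 120 s) would correspond to a ratio above 10, while the plain-tensorflow run should stay close to 1.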
I tried various how-tos on YouTube and GitHub. I have conda.
The third step fails:
conda install -c apple tensorflow-deps
pip install tensorflow-macos
pip install tensorflow-metal
ERROR: Could not find a version that satisfies the requirement tensorflow-metal (from versions: none)
ERROR: No matching distribution found for tensorflow-metal
I see a lot of fixes for Intel-based Macs, but none for M3. Help!?
I tried running inference with the 2B model from https://github.com/google-deepmind/gemma on my M2 MacBook Pro, but it segfaults during sampling: https://pastebin.com/KECyz60T
Note: out of the box it will try to load bfloat16 weights, which will fail. To avoid this, I patched line 30 in gemma/params.py to explicitly cast to float32:
param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)
tensorflow-metal was working on my Mac (Apple M3 Pro) until yesterday. Then my code started freezing. I ran the test script from https://developer.apple.com/metal/tensorflow-plugin/ and it now crashes; this used to work fine, but all of a sudden it does not. The results are shown below. Has anyone seen anything like this? Could this be a hardware problem?
MacBook-Pro-3: carl$ python mac_tensorflow_test.py
Epoch 1/5
1/782 [..............................] - ETA: 51:53 - loss: 6.0044 - accuracy: 0.0312Error: command buffer exited with error status.
The Metal Performance Shaders operations encoded on it may not have completed.
Error:
(null)
Ignored (for causing prior/excessive GPU errors) (00000004:kIOGPUCommandBufferCallbackErrorSubmissionsIgnored)
<AGXG15XFamilyCommandBuffer: 0x1172515e0>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
commandQueue = <AGXG15XFamilyCommandQueue: 0x17427e400>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
retainedReferences = 1
Error: command buffer exited with error status.
The Metal Performance Shaders operations encoded on it may not have completed.
Error:
(null)
Ignored (for causing prior/excessive GPU errors) (00000004:kIOGPUCommandBufferCallbackErrorSubmissionsIgnored)
<AGXG15XFamilyCommandBuffer: 0x117257b40>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
commandQueue = <AGXG15XFamilyCommandQueue: 0x17427e400>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
retainedReferences = 1
Many more rows of similar printouts follow.
Hi, I am trying to set up tensorflow-metal as instructed at https://developer.apple.com/metal/tensorflow-plugin/.
When running (python -m pip install tensorflow-metal), I get the following error:
ERROR: Could not find a version that satisfies the requirement tensorflow-metal (from versions: none)
ERROR: No matching distribution found for tensorflow-metal
According to the troubleshooting section: "Check that the Python version used in the environment is supported (Python 3.8, Python 3.9, Python 3.10)." My current version is Python 3.9.12.
Any insight would be great!
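"No matching distribution" means pip found no wheel whose tags match the interpreter. A hedged checklist in code: the supported Python versions below come from the troubleshooting text quoted above (later plugin releases may accept others), and the expectation that only macOS arm64 wheels are published for recent tensorflow-metal releases is an assumption worth confirming against the PyPI file list. The helper name is hypothetical.

```python
import platform
import sys

# Python versions listed in the troubleshooting text quoted above.
SUPPORTED_PYTHON = {(3, 8), (3, 9), (3, 10)}


def install_blockers(system: str, machine: str, version: tuple) -> list:
    """Reasons pip might find no tensorflow-metal wheel for this interpreter."""
    blockers = []
    if system != "Darwin":
        blockers.append(f"not running on macOS (platform.system() == {system!r})")
    if machine != "arm64":
        # An x86_64 Python (Intel-only conda or Rosetta) asks pip for x86_64 wheels.
        blockers.append(f"interpreter architecture is {machine!r}, not 'arm64'")
    if tuple(version[:2]) not in SUPPORTED_PYTHON:
        blockers.append(f"Python {version[0]}.{version[1]} is not in the supported list")
    return blockers


print(install_blockers(platform.system(), platform.machine(), sys.version_info))
```

Since Python 3.9.12 is already in the supported list, an architecture entry in the output (for example `'x86_64'` on an M-series Mac) would be the more likely explanation.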
Hi,
I am looking for a routine to perform complex-valued linear algebra on the GPU in Python for scientific programming, in particular quantum physics simulations.
At the moment I am looking for a routine for complex-valued matrix multiplication. I found that MLX has a routine for floating-point matrix multiplication, but it does not directly work for complex-valued matrices. I figured out a workaround by splitting the complex-valued matrix into its real and imaginary parts and working with the pair, but it is cumbersome to integrate with the rest of the code. I was hoping for a library-based implementation similar to CuPy.
I also tried the TensorFlow linear algebra routines, but so far I haven't managed to get them running on the GPU. Specifically, a test file with a tensorflow.keras.applications.ResNet50 routine runs on the GPU, but the routines from tensorflow.linalg and tensorflow.math that I tested (matmul, expm, eigh) did not run on the GPU.
Any advice on how to make linear algebra calculations work on Mac GPUs is highly appreciated! For my application, the unified memory might be especially beneficial.
Thank you!
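The split real/imaginary workaround can be hidden behind a small wrapper so that the rest of the code keeps writing `a @ b`. This minimal sketch uses the identity (A + iB)(C + iD) = (AC - BD) + i(AD + BC) and assumes the backing arrays support `@`, `+` and `-`. It is checked here with NumPy; whether a given GPU array library (MLX or another) accepts these operands unchanged is not verified, and the class name `SplitComplex` is hypothetical.

```python
from typing import Any, NamedTuple

import numpy as np


class SplitComplex(NamedTuple):
    """A complex matrix stored as a (real, imaginary) pair of real arrays."""
    re: Any
    im: Any

    def __matmul__(self, other: "SplitComplex") -> "SplitComplex":
        # (A + iB)(C + iD) = (AC - BD) + i(AD + BC): four real matmuls.
        return SplitComplex(self.re @ other.re - self.im @ other.im,
                            self.re @ other.im + self.im @ other.re)

    @classmethod
    def from_complex(cls, z: np.ndarray) -> "SplitComplex":
        return cls(np.ascontiguousarray(z.real), np.ascontiguousarray(z.imag))

    def to_complex(self) -> np.ndarray:
        return np.asarray(self.re) + 1j * np.asarray(self.im)


rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4)) + 1j * rng.normal(size=(3, 4))
B = rng.normal(size=(4, 2)) + 1j * rng.normal(size=(4, 2))
C = (SplitComplex.from_complex(A) @ SplitComplex.from_complex(B)).to_complex()
```

Gauss's trick can cut this to three real products (AC, BD and (A + B)(C + D)) at some cost in numerical stability; the four-product form above is the simpler starting point.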
InvalidArgumentError: Cannot assign a device for operation don_nn/model_2/branch_hidden0/MatMul/ReadVariableOp: Could not satisfy explicit device specification '' because the node {{colocation_node don_nn/model_2/branch_hidden0/MatMul/ReadVariableOp}} was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0].
Problem
I am trying to use the jax.numpy.einsum function (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html). However, for some subscripts, this seems to fail.
Hardware
Apple M1 Max, 32GB RAM
Steps to Reproduce
follow installation steps from https://developer.apple.com/metal/jax/
conda create -n 'jax_metal_demo' python=3.11
conda activate jax_metal_demo
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal
Save the following code in a file called minimal_example.py
import numpy as np
from jax import device_put
import jax.numpy as jnp
np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)
subscripts = 'ijklm,ijk->lmk'
# intended result
print(np.einsum(subscripts, a, b))
# will cause crash
a, b = device_put(a), device_put(b)
print(jnp.einsum(subscripts, a, b))
run the code
python minimal_example.py
Output
I was expecting the same result as np.einsum; instead I got:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-02-12 16:45:34.684973: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
Traceback (most recent call last):
File "/Users/linus/workspace/minimal_example.py", line 15, in <module>
print(jnp.einsum(subscripts, a, b))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/linus/workspace/minimal_example.py:15:6: error: failed to legalize operation 'mhlo.dot_general'
print(jnp.einsum(subscripts, a, b))
^
/Users/linus/workspace/minimal_example.py:15:6: note: see current operation: %0 = "mhlo.dot_general"(%arg1, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [2], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [0, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<11x12x13xf32>, tensor<11x12x13x11x12xf32>) -> tensor<13x11x12xf32>
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Conclusion
I would greatly appreciate any ideas for workarounds.
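For this particular contraction, one workaround is to avoid the batched dot_general entirely: broadcast `b` over the trailing l, m axes, multiply elementwise, reduce over i and j, then move k to the end. It materializes the full i*j*k*l*m product (about 226k elements for these shapes), so it only suits moderate sizes, and it is an assumption that the multiply/reduce/transpose ops lower cleanly on Metal. The helper name is hypothetical.

```python
import numpy as np
import jax.numpy as jnp


def einsum_ijklm_ijk_to_lmk(a, b):
    """Same result as jnp.einsum('ijklm,ijk->lmk', a, b) without a dot_general."""
    prod = a * b[:, :, :, None, None]        # shape (i, j, k, l, m)
    summed = prod.sum(axis=(0, 1))           # shape (k, l, m)
    return jnp.transpose(summed, (1, 2, 0))  # shape (l, m, k)


np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)
out = einsum_ijklm_ijk_to_lmk(jnp.asarray(a), jnp.asarray(b))
print(out.shape)
```

JAX converts the float64 NumPy inputs to float32 by default, so compare against `np.einsum` with a float32-appropriate tolerance.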
MacBook Pro M2 Max / 64 GB / macOS 13.2.1 (22D68)
import tensorflow as tf
def runMnist(device='/device:CPU:0'):
    with tf.device(device):
        # tf.config.set_default_device(device)
        mnist = tf.keras.datasets.mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(optimizer='adam',
                      loss=loss_fn,
                      metrics=['accuracy'])
        model.fit(x_train, y_train, epochs=10)
runMnist(device='/device:CPU:0')
runMnist(device='/device:GPU:0')