Installation
============


First, create and activate a virtual environment, for example with pyenv or conda.

Install using pip:

.. code-block:: bash

   pip install tn4ml

Install directly from the git repository:

.. code-block:: bash

   pip install git+https://github.com/bsc-quantic/tn4ml.git

Alternatively, clone the repository, change into its root directory, and install from there:

.. code-block:: bash

   git clone https://github.com/bsc-quantic/tn4ml.git
   cd tn4ml
   pip install .

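Or, from the same directory, install the package in editable (development) mode, so that changes to the source take effect without reinstalling:

.. code-block:: bash

   pip install -e .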
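To confirm that whichever method you used installed the package into the active environment, you can query its installed version. This is a minimal sketch that relies only on the standard library's ``importlib.metadata``; the helper ``installed_version`` is a hypothetical name, not part of tn4ml.

.. code-block:: python

   from importlib.metadata import PackageNotFoundError, version
   from typing import Optional


   def installed_version(dist_name: str = "tn4ml") -> Optional[str]:
       """Return the installed version of ``dist_name``, or None if it is not installed."""
       try:
           return version(dist_name)
       except PackageNotFoundError:
           return None


   v = installed_version()
   print(f"tn4ml {v}" if v is not None else "tn4ml is not installed in this environment")

If this prints that tn4ml is not installed, the ``pip`` you ran most likely belongs to a different environment than the active ``python``.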

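To run the tests, install the package together with its test dependencies:

.. code-block:: bash

   # Quoting prevents shells such as zsh from expanding the brackets
   pip install ".[test]"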

Run the tests:

.. code-block:: bash

   pytest

Accelerated runtime
-------------------

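(Optional) To improve numerical precision, enable 64-bit floating point and the highest matrix-multiplication precision at the start of your script:

.. code-block:: python

   import jax

   # Use float64 instead of JAX's default float32
   jax.config.update("jax_enable_x64", True)
   # Use the most accurate algorithm for matrix multiplications
   jax.config.update("jax_default_matmul_precision", "highest")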
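A quick way to check that the 64-bit flag took effect is to inspect the dtype of a newly created array. The sketch below repeats the two flags so that it runs on its own; with ``jax_enable_x64`` enabled, floating-point arrays default to ``float64``.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   jax.config.update("jax_enable_x64", True)
   jax.config.update("jax_default_matmul_precision", "highest")

   x = jnp.ones((2, 2))
   y = x @ x  # matrix product computed under the "highest" precision setting

   # Both dtypes should be float64 now that x64 mode is enabled
   print(x.dtype, y.dtype)

If the dtype is still ``float32``, the flag was most likely set after arrays had already been created.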

Running on GPU
~~~~~~~~~~~~~~

Before anything else, install a CUDA-enabled build of JAX that can run on your GPU. For installation instructions, see the JAX documentation for ``jax[cuda]``.

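Next, at the very beginning of your script (before importing JAX), set:

.. code-block:: python

   import os

   # Expose only GPU 0 to JAX; change the ID to select a different GPU.
   # This must be set before JAX initializes its GPU backend.
   os.environ["CUDA_VISIBLE_DEVICES"] = "0"

   import jax

   jax.config.update("jax_platform_name", "gpu")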
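To check which backend JAX actually picked up, you can list the devices it sees. This sketch uses only ``jax.default_backend()`` and ``jax.devices()``; the helper ``gpu_available`` is a hypothetical name. It deliberately does not force the ``gpu`` platform, so it also runs on a CPU-only machine.

.. code-block:: python

   import jax


   def gpu_available() -> bool:
       """Return True if JAX's default backend exposes at least one GPU device."""
       return any(device.platform == "gpu" for device in jax.devices())


   print("Default backend:", jax.default_backend())
   print("Devices:", jax.devices())
   print("GPU available:", gpu_available())

If no GPU is reported, the CUDA-enabled JAX build is probably not installed, or the selected GPU ID is not visible to the process.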

Then, when training a ``Model``, set the device:

.. code-block:: python

   device = "gpu"
   model.configure(device=device)  # ``model`` is the model you are training