diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index cc577f55..82028207 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -73,6 +73,8 @@ from .context import ( ArrayContext, ArrayContextFactory, + CSRMatrix, + SparseMatrix, tag_axes, ) from .impl.jax import EagerJAXArrayContext @@ -129,6 +131,7 @@ "ArrayOrScalarT", "ArrayT", "BcastUntilActxArray", + "CSRMatrix", "CommonSubexpressionTag", "ContainerOrScalarT", "EagerJAXArrayContext", @@ -144,6 +147,7 @@ "ScalarLike", "SerializationKey", "SerializedContainer", + "SparseMatrix", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", diff --git a/arraycontext/context.py b/arraycontext/context.py index e0b2b907..83ed9a1a 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -78,6 +78,9 @@ .. autoclass:: ArrayContext +.. autoclass:: SparseMatrix +.. autoclass:: CSRMatrix + .. autofunction:: tag_axes .. class:: P @@ -114,6 +117,7 @@ """ +import dataclasses from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Mapping from typing import ( @@ -121,6 +125,7 @@ Any, ParamSpec, TypeAlias, + cast, overload, ) from warnings import warn @@ -131,6 +136,9 @@ # FIXME: remove sometime, this import was used in grudge in July 2025. 
from .typing import ArrayOrArithContainerTc as ArrayOrArithContainerTc +from arraycontext.container.traversal import ( + rec_map_container, +) if TYPE_CHECKING: @@ -138,15 +146,17 @@ from numpy.typing import DTypeLike import loopy - from pytools.tag import ToTagSetConvertible + from pytools.tag import Tag, ToTagSetConvertible from .fake_numpy import BaseFakeNumpyNamespace from .typing import ( Array, ArrayContainerT, ArrayOrArithContainerOrScalarT, + ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, + ArrayOrScalar, ContainerOrScalarT, NumpyOrContainerOrScalar, ScalarLike, @@ -155,6 +165,26 @@ P = ParamSpec("P") +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() + + +@dataclasses.dataclass(frozen=True, eq=False, repr=False) +class SparseMatrix(ABC): + shape: tuple[int, int] + tags: ToTagSetConvertible = dataclasses.field(kw_only=True) + axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True) + _actx: ArrayContext = dataclasses.field(kw_only=True) + + def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer: + return self._actx.sparse_matmul(self, other) + + +@dataclasses.dataclass(frozen=True, eq=False, repr=False) +class CSRMatrix(SparseMatrix): + elem_values: Array + elem_col_indices: Array + row_starts: Array + # {{{ ArrayContext @@ -172,6 +202,8 @@ class ArrayContext(ABC): .. automethod:: to_numpy .. automethod:: call_loopy .. automethod:: einsum + .. automethod:: make_csr_matrix + .. automethod:: sparse_matmul .. attribute:: np Provides access to a namespace that serves as a work-alike to @@ -424,6 +456,178 @@ def einsum(self, )["out"] return self.tag(tagged, out_ary) + def make_csr_matrix( + self, + shape: tuple[int, int], + elem_values: Array, + elem_col_indices: Array, + row_starts: Array, + *, + tags: ToTagSetConvertible = _EMPTY_TAG_SET, + axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix: + """Return a sparse matrix in compressed sparse row (CSR) format, to be used + with :meth:`sparse_matmul`. 
+ + :arg shape: the (two-dimensional) shape of the matrix + :arg elem_values: a one-dimensional array containing the values of all of the + nonzero entries of the matrix, grouped by row. + :arg elem_col_indices: a one-dimensional array containing the column index + values corresponding to each entry in *elem_values*. + :arg row_starts: a one-dimensional array of length `nrows+1`, where each entry + gives the starting index in *elem_values* and *elem_col_indices* for the + given row, with the last entry being equal to the total number of nonzero entries. + """ + if axes is None: + axes = (frozenset(), frozenset()) + + return CSRMatrix( + shape, elem_values, elem_col_indices, row_starts, + tags=tags, axes=axes, + _actx=self) + + @memoize_method + def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit: + import numpy as np + + import loopy as lp + + out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim)) + out_inames = ("irow", *out_extra_inames) + out_inames_set = frozenset(out_inames) + + out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim)) + out_shape_comp_names = ("nrows", *out_extra_shape_comp_names) + + domains: list[str] = [] + domains.append( + "{ [" + ",".join(out_inames) + "] : " + + " and ".join( + f"0 <= {iname} < {shape_comp_name}" + for iname, shape_comp_name in zip( + out_inames, out_shape_comp_names, strict=True)) + + " }") + domains.append( + "{ [iel] : iel_lbound <= iel < iel_ubound }") + + temporary_variables: Mapping[str, lp.TemporaryVariable] = { + "iel_lbound": lp.TemporaryVariable( + "iel_lbound", + shape=(), + address_space=lp.AddressSpace.GLOBAL, + # FIXME: Need to do anything with tags? + ), + "iel_ubound": lp.TemporaryVariable( + "iel_ubound", + shape=(), + address_space=lp.AddressSpace.GLOBAL, + # FIXME: Need to do anything with tags? + )} + + from loopy.kernel.instruction import make_assignment + from pymbolic import var + # FIXME: Need tags for any of these? 
+ instructions: list[lp.Assignment | lp.CallInstruction] = [ + make_assignment( + (var("iel_lbound"),), + var("row_starts")[var("irow")], + id="insn0", + within_inames=out_inames_set), + make_assignment( + (var("iel_ubound"),), + var("row_starts")[var("irow") + 1], + id="insn1", + within_inames=out_inames_set), + make_assignment( + (var("out")[tuple(var(iname) for iname in out_inames)],), + lp.Reduction( + "sum", + (var("iel"),), + var("elem_values")[var("iel"),] + * var("array")[( + var("elem_col_indices")[var("iel"),], + *(var(iname) for iname in out_extra_inames))]), + id="insn2", + within_inames=out_inames_set, + depends_on=frozenset({"insn0", "insn1"}))] + + from loopy.version import MOST_RECENT_LANGUAGE_VERSION + + from .loopy import _DEFAULT_LOOPY_OPTIONS + + knl = lp.make_kernel( + domains=domains, + instructions=instructions, + temporary_variables=temporary_variables, + kernel_data=[ + lp.ValueArg("nrows", is_input=True), + lp.ValueArg("ncols", is_input=True), + lp.ValueArg("nels", is_input=True), + *( + lp.ValueArg(shape_comp_name, is_input=True) + for shape_comp_name in out_extra_shape_comp_names), + lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True), + lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True), + lp.GlobalArg("row_starts", shape=lp.auto, is_input=True), + lp.GlobalArg( + "array", + shape=( + var("ncols"), + *( + var(shape_comp_name) + for shape_comp_name in out_extra_shape_comp_names),), + is_input=True), + lp.GlobalArg( + "out", + shape=( + var("nrows"), + *( + var(shape_comp_name) + for shape_comp_name in out_extra_shape_comp_names),), + is_input=False), + ...], + name="csr_matmul_kernel", + lang_version=MOST_RECENT_LANGUAGE_VERSION, + options=_DEFAULT_LOOPY_OPTIONS, + default_order=lp.auto, + default_offset=lp.auto, + # FIXME: Need to do anything with tags? 
+ ) + + idx_dtype = knl.default_entrypoint.index_dtype + + return lp.add_and_infer_dtypes( + knl, + { + ",".join([ + "ncols", "nrows", "nels", + *out_extra_shape_comp_names]): idx_dtype, + "elem_values,array,out": np.float64, + "elem_col_indices,row_starts": idx_dtype}) + + def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + """Multiply a sparse matrix by an array. + + :arg x1: the sparse matrix. + :arg x2: the array. + """ + if isinstance(x1, CSRMatrix): + def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar: + assert self.is_array_type(ary) + prg = self._get_csr_matmul_prg(len(ary.shape)) + out_ary = self.call_loopy( + prg, elem_values=x1.elem_values, + elem_col_indices=x1.elem_col_indices, + row_starts=x1.row_starts, array=ary)["out"] + # FIXME + # return self.tag(tagged, out_ary) + return out_ary + + return cast("ArrayOrContainer", rec_map_container(_matmul, x2)) + + else: + raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'") + @abstractmethod def clone(self) -> Self: """If possible, return a version of *self* that is semantically diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 73f29cf5..e9230ca1 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -33,14 +33,16 @@ from typing import TYPE_CHECKING, cast import numpy as np +from typing_extensions import override from arraycontext.container.traversal import ( rec_map_container, with_array_context, ) -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, CSRMatrix, SparseMatrix from arraycontext.typing import ( Array, + ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrScalar, ScalarLike, @@ -51,7 +53,10 @@ if TYPE_CHECKING: from collections.abc import Callable - from pytools.tag import ToTagSetConvertible + from pytools.tag import Tag, ToTagSetConvertible + + +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() class 
EagerJAXArrayContext(ArrayContext): @@ -150,6 +155,23 @@ def einsum(self, spec, *args, arg_names=None, tagged=()): import jax.numpy as jnp return jnp.einsum(spec, *args) + @override + def make_csr_matrix( + self, + shape: tuple[int, int], + elem_values: Array, + elem_col_indices: Array, + row_starts: Array, + *, + tags: ToTagSetConvertible = _EMPTY_TAG_SET, + axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + + @override + def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + def clone(self): return type(self)() diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 265d7a53..810cbfb5 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -47,12 +47,16 @@ ) from arraycontext.context import ( ArrayContext, + CSRMatrix, + SparseMatrix, UntransformedCodeWarning, ) from arraycontext.typing import ( Array, + ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, + ArrayOrScalar, ContainerOrScalarT, NumpyOrContainerOrScalar, is_scalar_like, @@ -61,11 +65,14 @@ if TYPE_CHECKING: from pymbolic import Scalar - from pytools.tag import ToTagSetConvertible + from pytools.tag import Tag, ToTagSetConvertible from arraycontext.typing import ArrayContainerT +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() + + class NumpyNonObjectArrayMetaclass(type): @override def __instancecheck__(cls, instance: object) -> bool: @@ -199,6 +206,32 @@ def tag_axis(self, def einsum(self, spec, *args, arg_names=None, tagged=()): return np.einsum(spec, *args, optimize="optimal") + @override + def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + if isinstance(x1, CSRMatrix): + assert isinstance(x1.elem_values, np.ndarray) + assert isinstance(x1.elem_col_indices, 
np.ndarray) + assert isinstance(x1.row_starts, np.ndarray) + + # FIXME: Not sure if the scipy dependency is OK or if it should just use + # the call_loopy fallback? Currently getting errors with the loopy version: + # loopy.diagnostic.LoopyError: One of the kernels in the program has + # been preprocessed, cannot modify target now. + from scipy.sparse import csr_matrix + np_matrix = csr_matrix( + (x1.elem_values, x1.elem_col_indices, x1.row_starts), + shape=x1.shape) + + def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar: + assert isinstance(ary, np.ndarray) + return np_matrix @ ary + + return cast("ArrayOrContainer", rec_map_container(_matmul, x2)) + + else: + raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'") + @property def permits_inplace_modification(self): return True diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 6e883f55..6a0d569e 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -59,7 +59,7 @@ import numpy as np from typing_extensions import override -from pytools import memoize_method +from pytools import memoize_in, memoize_method from pytools.tag import Tag, ToTagSetConvertible, normalize_tags from arraycontext.container.traversal import ( @@ -68,13 +68,16 @@ ) from arraycontext.context import ( ArrayContext, + CSRMatrix, P, + SparseMatrix, UntransformedCodeWarning, ) from arraycontext.metadata import NameHint from arraycontext.typing import ( Array, ArrayOrArithContainerOrScalarT, + ArrayOrContainer, ArrayOrContainerOrScalarT, ArrayOrScalar, ScalarLike, @@ -833,6 +836,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays dag = pt.transform.materialize_with_mpms(dag) return dag + @override def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt @@ -876,6 +880,32 @@ def preprocess_arg(name, arg): for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged)) + @override + 
def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + import pytato as pt + + if isinstance(x1, CSRMatrix): + @memoize_in(x1, "pt_matrix") + def _get_pt_matrix() -> pt.CSRMatrix: + assert isinstance(x1.elem_values, pt.Array) + assert isinstance(x1.elem_col_indices, pt.Array) + assert isinstance(x1.row_starts, pt.Array) + return pt.make_csr_matrix( + x1.shape, x1.elem_values, x1.elem_col_indices, x1.row_starts, + tags=_preprocess_array_tags(x1.tags), axes=x1.axes) + + pt_matrix: pt.CSRMatrix = _get_pt_matrix() + + def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar: + assert isinstance(ary, pt.Array) + return pt_matrix @ ary + + return cast("ArrayOrContainer", rec_map_container(_matmul, x2)) + + else: + raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'") + def clone(self): return type(self)(self.queue, self.allocator) @@ -1115,6 +1145,23 @@ def preprocess_arg(name: str | None, arg: Array): for name, arg in zip(arg_names, args, strict=True) ]).tagged(_preprocess_array_tags(tagged))) + @override + def make_csr_matrix( + self, + shape: tuple[int, int], + elem_values: Array, + elem_col_indices: Array, + row_starts: Array, + *, + tags: ToTagSetConvertible = _EMPTY_TAG_SET, + axes: tuple[ToTagSetConvertible, ...] 
| None = None) -> CSRMatrix: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + + @override + def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + @override def clone(self): return type(self)() diff --git a/requirements.txt b/requirements.txt index a4cb4025..54a2a5ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/majosm/pytato.git@sparse-matrix#egg=pytato diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 6d7a38a4..bae0f12e 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -37,10 +37,12 @@ from arraycontext import ( ArrayContextFactory, + ArrayOrScalar, BcastUntilActxArray, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, + PytatoJAXArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, pytest_generate_tests_for_array_contexts, @@ -656,6 +658,71 @@ def test_array_context_einsum_array_tripleprod(actx_factory: ArrayContextFactory # }}} +def test_array_context_csr_matmul(actx_factory: ArrayContextFactory): + actx = actx_factory() + + if isinstance(actx, (EagerJAXArrayContext, PytatoJAXArrayContext)): + pytest.skip(f"not implemented for '{type(actx).__name__}'") + + n = 100 + + x = actx.from_numpy(np.arange(n, dtype=np.float64)) + ary_of_x = obj_array.new_1d([x] * 3) + dc_of_x = MyContainer( + name="container", + mass=x, + momentum=obj_array.new_1d([x] * 3), + enthalpy=x) + + elem_values = actx.zeros((n//2,), dtype=np.float64) + 1. 
+ elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32)) + row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32)) + + mat = actx.make_csr_matrix( + shape=(n//2, n), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + + expected_mat_x = actx.from_numpy(2 * np.arange(n//2, dtype=np.float64)) + + def _check_allclose( + arg1: ArrayOrScalar, arg2: ArrayOrScalar, atol: float = 1.0e-14): + from arraycontext import NotAnArrayContainerError + try: + arg1_iterable = serialize_container(arg1) + arg2_iterable = serialize_container(arg2) + except NotAnArrayContainerError: + assert np.linalg.norm(actx.to_numpy(arg1 - arg2)) < atol + else: + arg1_subarrays = [ + subarray for _, subarray in arg1_iterable] + arg2_subarrays = [ + subarray for _, subarray in arg2_iterable] + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, + strict=True): + _check_allclose(subarray1, subarray2) + + # single array + res = mat @ x + expected_res = expected_mat_x + _check_allclose(res, expected_res) + + # array of arrays + res = mat @ ary_of_x + expected_res = obj_array.new_1d([expected_mat_x] * 3) + _check_allclose(res, expected_res) + + # container of arrays + res = mat @ dc_of_x + expected_res = MyContainer( + name="container", + mass=expected_mat_x, + momentum=obj_array.new_1d([expected_mat_x] * 3), + enthalpy=expected_mat_x) + _check_allclose(res, expected_res) + + # {{{ array container classes for test