diff --git a/rayflare/matrix_formalism/multiply_matrices.py b/rayflare/matrix_formalism/multiply_matrices.py index 4f65ecc..3b4e196 100644 --- a/rayflare/matrix_formalism/multiply_matrices.py +++ b/rayflare/matrix_formalism/multiply_matrices.py @@ -6,7 +6,7 @@ # Contact: p.pearce@unsw.edu.au import numpy as np -from sparse import load_npz, COO, stack, einsum +from sparse import load_npz, COO, stack, einsum, dot from rayflare.angles import make_angle_vector, fold_phi, overall_bin import os import xarray as xr @@ -162,17 +162,31 @@ def make_D(alphas, thick, thetas): # (GitHub: arsonwong) def dot_wl(mat, vec): + # if len(mat.shape) == 3: + # result = einsum('ijk,ik->ij', mat, COO(vec)).todense() + # + # if len(mat.shape) == 2: + # result = einsum('jk,ik->ij', mat, COO(vec)).todense() + result = np.empty((vec.shape[0], mat.shape[1])) + if len(mat.shape) == 3: - result = einsum('ijk,ik->ij', mat, COO(vec)).todense() + for i1 in range(vec.shape[0]): # loop over wavelengths + result[i1, :] = dot(mat[i1], vec[i1]) if len(mat.shape) == 2: - result = einsum('jk,ik->ij', mat, COO(vec)).todense() + for i1 in range(vec.shape[0]): # loop over wavelengths + result[i1, :] = dot(mat, vec[i1]) + return result def dot_wl_u2d(mat, vec): - result = einsum('jk,ik->ij', mat, COO(vec)).todense() + # result = einsum('jk,ik->ij', mat, COO(vec)).todense() + result = np.empty((vec.shape[0], vec.shape[1])) + for i1 in range(vec.shape[0]): # loop over wavelengths + result[i1, :] = dot(mat, vec[i1]) + return result