From 6da41b203e3ac841571da51217fe009797c17243 Mon Sep 17 00:00:00 2001
From: Phoebe Pearce
Date: Fri, 23 Aug 2024 09:12:10 +1000
Subject: [PATCH] see if removing einsum fixes nan issues
---
.../matrix_formalism/multiply_matrices.py | 22 +++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/rayflare/matrix_formalism/multiply_matrices.py b/rayflare/matrix_formalism/multiply_matrices.py
index 4f65ecc..3b4e196 100644
--- a/rayflare/matrix_formalism/multiply_matrices.py
+++ b/rayflare/matrix_formalism/multiply_matrices.py
@@ -6,7 +6,7 @@
# Contact: p.pearce@unsw.edu.au
import numpy as np
-from sparse import load_npz, COO, stack, einsum
+from sparse import load_npz, COO, stack, einsum, dot
from rayflare.angles import make_angle_vector, fold_phi, overall_bin
import os
import xarray as xr
@@ -162,17 +162,31 @@ def make_D(alphas, thick, thetas):
# (GitHub: arsonwong)
def dot_wl(mat, vec):
+ # if len(mat.shape) == 3:
+ # result = einsum('ijk,ik->ij', mat, COO(vec)).todense()
+ #
+ # if len(mat.shape) == 2:
+ # result = einsum('jk,ik->ij', mat, COO(vec)).todense()
+ result = np.empty((vec.shape[0], mat.shape[1]))
+
if len(mat.shape) == 3:
- result = einsum('ijk,ik->ij', mat, COO(vec)).todense()
+ for i1 in range(vec.shape[0]): # loop over wavelengths
+ result[i1, :] = dot(mat[i1], vec[i1])
if len(mat.shape) == 2:
- result = einsum('jk,ik->ij', mat, COO(vec)).todense()
+ for i1 in range(vec.shape[0]): # loop over wavelengths
+ result[i1, :] = dot(mat, vec[i1])
+
return result
def dot_wl_u2d(mat, vec):
- result = einsum('jk,ik->ij', mat, COO(vec)).todense()
+ # result = einsum('jk,ik->ij', mat, COO(vec)).todense()
+ result = np.empty((vec.shape[0], vec.shape[1]))
+ for i1 in range(vec.shape[0]): # loop over wavelengths
+ result[i1, :] = dot(mat, vec[i1])
+
return result