Skip to content

Commit

Permalink
see if removing einsum fixes nan issues
Browse files Browse the repository at this point in the history
  • Loading branch information
phoebe-p committed Aug 22, 2024
1 parent 0795730 commit 6da41b2
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions rayflare/matrix_formalism/multiply_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Contact: p.pearce@unsw.edu.au

import numpy as np
from sparse import load_npz, COO, stack, einsum
from sparse import load_npz, COO, stack, einsum, dot
from rayflare.angles import make_angle_vector, fold_phi, overall_bin
import os
import xarray as xr
Expand Down Expand Up @@ -162,17 +162,31 @@ def make_D(alphas, thick, thetas):
# (GitHub: arsonwong)
def dot_wl(mat, vec):

# if len(mat.shape) == 3:
# result = einsum('ijk,ik->ij', mat, COO(vec)).todense()
#
# if len(mat.shape) == 2:
# result = einsum('jk,ik->ij', mat, COO(vec)).todense()
result = np.empty((vec.shape[0], mat.shape[1]))

if len(mat.shape) == 3:
result = einsum('ijk,ik->ij', mat, COO(vec)).todense()
for i1 in range(vec.shape[0]): # loop over wavelengths
result[i1, :] = dot(mat[i1], vec[i1])

if len(mat.shape) == 2:
result = einsum('jk,ik->ij', mat, COO(vec)).todense()
for i1 in range(vec.shape[0]): # loop over wavelengths
result[i1, :] = dot(mat, vec[i1])


return result


def dot_wl_u2d(mat, vec):
result = einsum('jk,ik->ij', mat, COO(vec)).todense()
# result = einsum('jk,ik->ij', mat, COO(vec)).todense()
result = np.empty((vec.shape[0], vec.shape[1]))
for i1 in range(vec.shape[0]): # loop over wavelengths
result[i1, :] = dot(mat, vec[i1])

return result


Expand Down

0 comments on commit 6da41b2

Please sign in to comment.