Skip to content

Commit

Permalink
'changes'
Browse files Browse the repository at this point in the history
  • Loading branch information
risi-kondor committed Oct 12, 2023
1 parent 22cc67b commit db7189b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
1 change: 0 additions & 1 deletion objects/SO3/functions/SO3part_addCGproductFn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ namespace GElib{

void operator()(const SO3part3_view& _r, const SO3part3_view& _x, const SO3part3_view& _y, const int _offs=0){


const int l=_r.getl();
const int l1=_x.getl();
const int l2=_y.getl();
Expand Down
44 changes: 35 additions & 9 deletions objects/SO3c/SO3partC.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@
#define _GElibSO3partC

#include "GElib_base.hpp"
#include "LtensorView.hpp"
#include "Ltensor.hpp"
#include "SO3partSpec.hpp"
#include "diff_class.hpp"
#include "WorkStreamLoop.hpp"

#include "SO3part_addCGproductFn.hpp"


namespace GElib{


template<typename TYPE>
class SO3part: public cnine::LtensorView<TYPE>,
class SO3part: public cnine::Ltensor<complex<TYPE> >,
public cnine::diff_class<SO3part<TYPE> >{
public:

typedef cnine::LtensorView<TYPE> BASE;
typedef cnine::Ltensor<complex<TYPE> > BASE;
typedef cnine::diff_class<SO3part<TYPE> > diff_class;

typedef cnine::Gdims Gdims;
Expand Down Expand Up @@ -70,6 +73,19 @@ namespace GElib{
static SO3partSpec<TYPE> sequential() {return SO3partSpec<TYPE>().sequential();}
static SO3partSpec<TYPE> gaussian() {return SO3partSpec<TYPE>().gaussian();}

SO3partSpec<TYPE> spec() const{
return BASE::spec();
}


public: // ---- Constructors ------------------------------------------------------------------------------


cnine::Ctensor3_view view3() const{
if(is_batched()) return cnine::TensorView<complex<TYPE> >::view3();
else return unsqueeze0(cnine::TensorView<complex<TYPE> >::view2());
}


public: // ---- Access -------------------------------------------------------------------------------------

Expand Down Expand Up @@ -124,16 +140,26 @@ namespace GElib{

};


/*

template<typename TYPE>
inline SO3part<TYPE> CGproduct(const BASE& x, const BASE& y, const int l){
assert(l>=abs(x.getl()-y.getl()) && l<=x.getl()+y.getl());
SO3part<TYPE> R=SO3part<TYPE>::zero(x.getb(),l,x.getn()*y.getn(),x.device());
R.add_CGproduct(x,y);
inline SO3part<TYPE> operator*(const SO3part<TYPE>& x, const cnine::Ltensor<complex<TYPE> >& y){
CNINE_ASSRT(y.ndims()==2);
CNINE_ASSRT(y.dim(0)==x.dims(-1));
SO3part<TYPE> R(x.spec().channels(y.dim(1)));
R.add_mprod(x,y);
return R;
}


template<typename TYPE>
inline SO3part<TYPE> CGproduct(const SO3part<TYPE>& x, const SO3part<TYPE>& y, const int l){
assert(l>=abs(x.getl()-y.getl()) && l<=x.getl()+y.getl());
SO3part<TYPE> r=SO3part<TYPE>::zero().l(l).n(x.getn()*y.getn()).dev(x.dev);
SO3part_addCGproductFn()(r.view3(),x.view3(),y.view3());
return r;
}

/*
template<typename TYPE>
inline SO3part<TYPE> DiagCGproduct(const BASE& x, const BASE& y, const int l){
assert(x.getn()==y.getn());
Expand Down
8 changes: 8 additions & 0 deletions objects/SO3c/SO3partSpec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace GElib{
public:

typedef cnine::TensorSpecBase<SO3partSpec<TYPE> > BASE;
using BASE::BASE;

using BASE::ddims;

Expand All @@ -36,11 +37,18 @@ namespace GElib{
SO3partSpec(const BASE& x):
BASE(x){}

SO3partSpec(const cnine::TensorSpec<complex<TYPE> > x):
BASE(reinterpret_cast<const BASE&>(x)){}

SO3part<TYPE> operator()(){
return SO3part<TYPE>(*this);
}


public: // ---- Copying -----------------------------------



public: // ---- Construction ------------------------------


Expand Down
16 changes: 7 additions & 9 deletions objects/SO3c/tests/testSO3part.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "GElib_base.cpp"
#include "GElibSession.hpp"
#include "SO3partC.hpp"
#include "Ltensor.hpp"

using namespace cnine;
using namespace GElib;
Expand All @@ -13,21 +14,18 @@ int main(int argc, char** argv){
int l=2;
int n=2;

SO3part<float> u=SO3part<float>::zero().l(l).n(n);
SO3part<float> u=SO3part<float>::gaussian().l(l).n(n);
SO3part<float> v=SO3part<float>::gaussian().batch(b).l(l).n(n);
cout<<u<<endl;
cout<<v<<endl;

//Tensor<complex<float> > M=Tensor<complex<float> >::gaussian({5,5});
Ltensor<complex<float> > M=Ltensor<complex<float> >::gaussian().dims({2,3});
cout<<M<<endl;
//cout<<M*u<<endl;
//cout<<u*M<<endl;
cout<<u*M<<endl;

//SO3part<float> w=CGproduct(u,v,2);
//cout<<w<<endl;
SO3part<float> w=CGproduct(u,u,2);
cout<<w<<endl;

//SO3part<float> v2=SO3part<float>::gaussian(l,n);
//cout<<CGproduct(u,v2,2)<<endl;

//cout<<DiagCGproduct(u,v,2)<<endl;

}

0 comments on commit db7189b

Please sign in to comment.