From db7189bbea87a042898a392b5e5cd2a518d7a385 Mon Sep 17 00:00:00 2001 From: Risi Kondor Date: Thu, 12 Oct 2023 02:44:51 -0500 Subject: [PATCH] 'changes' --- .../SO3/functions/SO3part_addCGproductFn.hpp | 1 - objects/SO3c/SO3partC.hpp | 44 +++++++++++++++---- objects/SO3c/SO3partSpec.hpp | 8 ++++ objects/SO3c/tests/testSO3part.cpp | 16 +++---- 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/objects/SO3/functions/SO3part_addCGproductFn.hpp b/objects/SO3/functions/SO3part_addCGproductFn.hpp index f9d8b40..bdda21e 100644 --- a/objects/SO3/functions/SO3part_addCGproductFn.hpp +++ b/objects/SO3/functions/SO3part_addCGproductFn.hpp @@ -40,7 +40,6 @@ namespace GElib{ void operator()(const SO3part3_view& _r, const SO3part3_view& _x, const SO3part3_view& _y, const int _offs=0){ - const int l=_r.getl(); const int l1=_x.getl(); const int l2=_y.getl(); diff --git a/objects/SO3c/SO3partC.hpp b/objects/SO3c/SO3partC.hpp index 610c91e..c3d036e 100644 --- a/objects/SO3c/SO3partC.hpp +++ b/objects/SO3c/SO3partC.hpp @@ -13,20 +13,23 @@ #define _GElibSO3partC #include "GElib_base.hpp" -#include "LtensorView.hpp" +#include "Ltensor.hpp" #include "SO3partSpec.hpp" #include "diff_class.hpp" #include "WorkStreamLoop.hpp" +#include "SO3part_addCGproductFn.hpp" + + namespace GElib{ template - class SO3part: public cnine::LtensorView, + class SO3part: public cnine::Ltensor >, public cnine::diff_class >{ public: - typedef cnine::LtensorView BASE; + typedef cnine::Ltensor > BASE; typedef cnine::diff_class > diff_class; typedef cnine::Gdims Gdims; @@ -70,6 +73,19 @@ namespace GElib{ static SO3partSpec sequential() {return SO3partSpec().sequential();} static SO3partSpec gaussian() {return SO3partSpec().gaussian();} + SO3partSpec spec() const{ + return BASE::spec(); + } + + + public: // ---- Constructors ------------------------------------------------------------------------------ + + + cnine::Ctensor3_view view3() const{ + if(is_batched()) return cnine::TensorView >::view3(); + else return unsqueeze0(cnine::TensorView >::view2()); + } + public: // ---- Access ------------------------------------------------------------------------------------- @@ -124,16 +140,26 @@ namespace GElib{ }; - - /* + template - inline SO3part CGproduct(const BASE& x, const BASE& y, const int l){ - assert(l>=abs(x.getl()-y.getl()) && l<=x.getl()+y.getl()); - SO3part R=SO3part::zero(x.getb(),l,x.getn()*y.getn(),x.device()); - R.add_CGproduct(x,y); + inline SO3part operator*(const SO3part& x, const cnine::Ltensor >& y){ + CNINE_ASSRT(y.ndims()==2); + CNINE_ASSRT(y.dim(0)==x.dims(-1)); + SO3part R(x.spec().channels(y.dim(1))); + R.add_mprod(x,y); return R; + } + + + template + inline SO3part CGproduct(const SO3part& x, const SO3part& y, const int l){ + assert(l>=abs(x.getl()-y.getl()) && l<=x.getl()+y.getl()); + SO3part r=SO3part::zero().l(l).n(x.getn()*y.getn()).dev(x.dev); + SO3part_addCGproductFn()(r.view3(),x.view3(),y.view3()); + return r; } + /* template inline SO3part DiagCGproduct(const BASE& x, const BASE& y, const int l){ assert(x.getn()==y.getn()); diff --git a/objects/SO3c/SO3partSpec.hpp b/objects/SO3c/SO3partSpec.hpp index a169a40..2f2cfc2 100644 --- a/objects/SO3c/SO3partSpec.hpp +++ b/objects/SO3c/SO3partSpec.hpp @@ -26,6 +26,7 @@ namespace GElib{ public: typedef cnine::TensorSpecBase > BASE; + using BASE::BASE; using BASE::ddims; @@ -36,11 +37,18 @@ namespace GElib{ SO3partSpec(const BASE& x): BASE(x){} + SO3partSpec(const cnine::TensorSpec > x): + BASE(reinterpret_cast(x)){} + SO3part operator()(){ return SO3part(*this); } + public: // ---- Copying ----------------------------------- + + + public: // ---- Construction ------------------------------ diff --git a/objects/SO3c/tests/testSO3part.cpp b/objects/SO3c/tests/testSO3part.cpp index faf61cd..a0aa346 100644 --- a/objects/SO3c/tests/testSO3part.cpp +++ b/objects/SO3c/tests/testSO3part.cpp @@ -1,6 +1,7 @@ #include "GElib_base.cpp" #include "GElibSession.hpp" #include "SO3partC.hpp" +#include "Ltensor.hpp" using namespace cnine; using namespace GElib; @@ -13,21 +14,18 @@ int main(int argc, char** argv){ int l=2; int n=2; - SO3part u=SO3part::zero().l(l).n(n); + SO3part u=SO3part::gaussian().l(l).n(n); SO3part v=SO3part::gaussian().batch(b).l(l).n(n); cout< > M=Tensor >::gaussian({5,5}); + Ltensor > M=Ltensor >::gaussian().dims({2,3}); + cout< w=CGproduct(u,v,2); - //cout< w=CGproduct(u,u,2); + cout< v2=SO3part::gaussian(l,n); - //cout<